diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 4406e58273..985b12df91 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -24,7 +24,7 @@ from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger(__name__)
-class Databases(object):
+class Databases:
"""The various databases.
These are low level interfaces to physical databases.
@@ -47,9 +47,14 @@ class Databases(object):
engine = create_engine(database_config.config)
with make_conn(database_config, engine) as db_conn:
- logger.info("Preparing database %r...", db_name)
-
+ logger.info("[database config %r]: Checking database server", db_name)
engine.check_database(db_conn)
+
+ logger.info(
+ "[database config %r]: Preparing for databases %r",
+ db_name,
+ database_config.databases,
+ )
prepare_database(
db_conn, engine, hs.config, databases=database_config.databases,
)
@@ -57,7 +62,9 @@ class Databases(object):
database = DatabasePool(hs, database_config, engine)
if "main" in database_config.databases:
- logger.info("Starting 'main' data store")
+ logger.info(
+ "[database config %r]: Starting 'main' database", db_name
+ )
# Sanity check we don't try and configure the main store on
# multiple databases.
@@ -72,7 +79,9 @@ class Databases(object):
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
- logger.info("Starting 'state' data store")
+ logger.info(
+ "[database config %r]: Starting 'state' database", db_name
+ )
# Sanity check we don't try and configure the state store on
# multiple databases.
@@ -85,14 +94,23 @@ class Databases(object):
self.databases.append(database)
- logger.info("Database %r prepared", db_name)
+ logger.info("[database config %r]: prepared", db_name)
+
+ # Closing the context manager doesn't close the connection.
+ # psycopg will close the connection when the object gets GCed, but *only*
+ # if the PID is the same as when the connection was opened [1], and
+ # it may not be if we fork in the meantime.
+ #
+ # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+ db_conn.close()
# Sanity check that we have actually configured all the required stores.
if not main:
- raise Exception("No 'main' data store configured")
+ raise Exception("No 'main' database configured")
if not state:
- raise Exception("No 'main' data store configured")
+ raise Exception("No 'state' database configured")
# We use local variables here to ensure that the databases do not have
# optional types.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..2ae2fbd5d7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,6 +18,7 @@
import calendar
import logging
import time
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -28,6 +29,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
+from synapse.types import get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
@@ -263,6 +265,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@@ -290,16 +295,16 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- def count_daily_users(self):
+ async def count_daily_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)
- def count_monthly_users(self):
+ async def count_monthly_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
@@ -307,7 +312,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@@ -326,15 +331,15 @@ class DataStore(
(count,) = txn.fetchone()
return count
- def count_r30_users(self):
+ async def count_r30_users(self) -> Dict[str, int]:
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
- Returns counts globaly for a given user as well as breaking
- by platform
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
@@ -407,7 +412,7 @@ class DataStore(
return results
- return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@@ -417,7 +422,7 @@ class DataStore(
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
- def generate_user_daily_visits(self):
+ async def generate_user_daily_visits(self) -> None:
"""
Generates daily visit data for use in cohort/ retention analysis
"""
@@ -472,18 +477,17 @@ class DataStore(
# frequently
self._last_user_visit_update = now
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
- def get_users(self):
+ async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
- Args:
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries representing users.
"""
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="users",
keyvalues={},
retcols=[
@@ -497,30 +501,40 @@ class DataStore(
desc="get_users",
)
- def get_users_paginate(
- self, start, limit, name=None, guests=True, deactivated=False
- ):
+ async def get_users_paginate(
+ self,
+ start: int,
+ limit: int,
+ user_id: Optional[str] = None,
+ name: Optional[str] = None,
+ guests: bool = True,
+ deactivated: bool = False,
+ ) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
- start (int): start number to begin the query from
- limit (int): number of rows to retrieve
- name (string): filter for user names
- guests (bool): whether to in include guest users
- deactivated (bool): whether to include deactivated users
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+            user_id: search for user_id. Ignored if name is not None
+ name: search for local part of user_id or display name
+            guests: whether to include guest users
+ deactivated: whether to include deactivated users
Returns:
- defer.Deferred: resolves to list[dict[str, Any]], int
+ A tuple of a list of mappings from user to information and a count of total users.
"""
def get_users_paginate_txn(txn):
filters = []
- args = []
+ args = [self.hs.config.server_name]
if name:
+ filters.append("(name LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ elif user_id:
filters.append("name LIKE ?")
- args.append("%" + name + "%")
+ args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,39 +544,42 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- args = [self.hs.config.server_name] + args + [limit, start]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
- ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
+ sql = "SELECT COUNT(*) as total_users " + sql_base
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = (
+ "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ + sql_base
+ + " ORDER BY u.name LIMIT ? OFFSET ?"
+ )
+ args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)
- def search_users(self, term):
+ async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
"""Function to search users list for one or more users with
the matched term.
Args:
- term (str): search term
- col (str): column to query term should be matched to
+ term: search term
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries or None.
"""
- return self.db_pool.simple_search_list(
+ return await self.db_pool.simple_search_list(
table="users",
term=term,
col="name",
@@ -575,21 +592,24 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig
"""Called before upgrading an existing database to check that it is broadly sane
compared with the configuration.
"""
- domain = config.server_name
+ logger.info("Checking database for consistency with configuration...")
- sql = database_engine.convert_param_style(
- "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
- )
- pat = "%:" + domain
- cur.execute(sql, (pat,))
- num_not_matching = cur.fetchall()[0][0]
- if num_not_matching == 0:
+ # if there are any users in the database, check that the username matches our
+ # configured server name.
+
+ cur.execute("SELECT name FROM users LIMIT 1")
+ rows = cur.fetchall()
+ if not rows:
+ return
+
+ user_domain = get_domain_from_id(rows[0][0])
+ if user_domain == config.server_name:
return
raise Exception(
"Found users in database not native to %s!\n"
- "You cannot changed a synapse server_name after it's been configured"
- % (domain,)
+ "You cannot change a synapse server_name after it's been configured"
+ % (config.server_name,)
)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
import abc
import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
- def get_account_data_for_user(self, user_id):
+ async def get_account_data_for_user(
+ self, user_id: str
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
- user_id(str): The user to get the account_data for.
+ user_id: The user to get the account_data for.
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A 2-tuple of a dict of global account_data and a dict mapping from
+ room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
- def get_account_data_for_room(self, user_id, room_id):
+ async def get_account_data_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
Returns:
- A deferred dict of the room account_data
+ A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
- def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ async def get_account_data_for_room_and_type(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
- account_data_type (str): The account data type to get.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
+ account_data_type: The account data type to get.
Returns:
- A deferred of the room account_data for that type, or None if
- there isn't any set.
+ The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ async def get_updated_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
- user_id(str): The user to get the account_data for.
- stream_id(int): The point in the stream since which to get updates
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return defer.succeed(({}, {}))
+ return ({}, {})
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -336,7 +341,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +389,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
- def _update_max_stream_id(self, next_id: int):
+ async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+ await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..454c0bc50c 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
import logging
import re
-from canonicaljson import json
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -162,20 +161,18 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- def set_appservice_state(self, service, state):
+ async def set_appservice_state(self, service, state) -> None:
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
- Returns:
- A Deferred which resolves when the state was set successfully.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
- def create_appservice_txn(self, service, events):
+ async def create_appservice_txn(self, service, events):
"""Atomically creates a new transaction for this application service
with the given list of events.
@@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
+ event_ids = json_encoder.encode([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
@@ -212,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
- def complete_appservice_txn(self, txn_id, service):
+ async def complete_appservice_txn(self, txn_id, service) -> None:
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
- Returns:
- A Deferred which resolves if this transaction was stored
- successfully.
"""
txn_id = int(txn_id)
@@ -261,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore(
{"txn_id": txn_id, "as_id": service.id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
@@ -315,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
- def set_appservice_last_pos(self, pos):
+ async def set_appservice_last_pos(self, pos) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
- def get_cache_stream_token(self, instance_name):
+ def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token(instance_name)
+ return self._cache_id_gen.get_current_token_for_writer(instance_name)
else:
return 0
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 216a5925fc..c2fc847fbc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
- def _update_client_ips_batch(self):
+ async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..0044433110 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
@trace
- def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+ async def delete_device_msgs_for_remote(
+ self, destination: str, up_to_stream_id: int
+ ) -> None:
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
- destination(str): The destination server_name
- up_to_stream_id(int): Where to delete messages up to.
- Returns:
- A deferred that resolves when the messages have been deleted.
+ destination: The destination server_name
+ up_to_stream_id: Where to delete messages up to.
"""
def delete_messages_for_remote_destination_txn(txn):
@@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..add4e3ea0e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,8 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -47,7 +48,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id: str, device_id: str):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -55,11 +56,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
- defer.Deferred for a dict containing the device information
+ A dict containing the device information
Raises:
StoreError: if the device is not found
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing an device update EDU
"""
devices = (
- await self.db_pool.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
+ await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -292,17 +291,11 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
+ keys = device.keys
+ if keys:
+ result["keys"] = keys
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -312,9 +305,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- def _get_last_device_update_for_remote_user(
+ async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
- ):
+ ) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@@ -325,12 +318,16 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+ return await self.db_pool.runInteraction(
+ "get_last_device_update_for_remote_user", f
+ )
- def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+ async def mark_as_sent_devices_by_remote(
+ self, destination: str, stream_id: int
+ ) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -380,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -412,8 +409,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
+ @abc.abstractmethod
def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
async def get_user_devices_from_cache(
@@ -481,53 +480,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id: str):
- """Get all devices (with any device keys) for a user
-
- Returns:
- Deferred which resolves to (stream_id, devices)
- """
- return self.db_pool.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(
- txn, [(user_id, None)], include_all_devices=True
- )
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in user_devices.items():
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
@@ -656,11 +608,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id: str):
+ async def get_device_list_last_stream_id_for_remote(
+ self, user_id: str
+ ) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -671,10 +625,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
- inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -715,11 +668,11 @@ class DeviceWorkerStore(SQLBaseStore):
return {row["user_id"] for row in rows}
- def mark_remote_user_device_cache_as_stale(self, user_id: str):
+ async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
@@ -727,7 +680,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+ async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
@@ -741,7 +694,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
@@ -1002,9 +955,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
- def update_remote_device_list_cache_entry(
+ async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
- ):
+ ) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@@ -1015,11 +968,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: ID of decivice being updated
content: new data on this device
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -1071,9 +1021,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(
+ async def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
@@ -1083,11 +1033,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: User to update device list for
devices: list of device objects supplied over federation
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1097,7 +1044,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
@@ -1147,7 +1094,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ with await self._device_list_id_gen.get_next_mult(
+ len(device_ids)
+ ) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
@@ -1160,7 +1109,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with self._device_list_id_gen.get_next_mult(
+ with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,7 +14,7 @@
# limitations under the License.
from collections import namedtuple
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
- def get_room_alias_creator(self, room_alias):
- return self.db_pool.simple_select_one_onecol(
+ async def get_room_alias_creator(self, room_alias: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
)
@cached(max_entries=5000)
- def get_aliases_for_room(self, room_id):
- return self.db_pool.simple_select_onecol(
+ async def get_aliases_for_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(
+ async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
- ):
+ ) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..12cecceec2 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
- def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -164,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
that we want to query
Returns:
- Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+ dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@@ -223,15 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
- def count_e2e_room_keys(self, user_id, version):
+ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@@ -281,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No current backup version")
return row[0]
- def get_e2e_room_keys_version_info(self, user_id, version=None):
+ async def get_e2e_room_keys_version_info(self, user_id, version=None):
"""Get info metadata about a version of our room_keys backup.
Args:
@@ -291,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
Returns:
- A deferred dict giving the info metadata for this backup version, with
+ A dict giving the info metadata for this backup version, with
fields including:
version(str)
algorithm(str)
@@ -322,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@trace
- def create_e2e_room_keys_version(self, user_id, info):
+ async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@@ -336,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
info(dict): the info about the backup version to be created
Returns:
- A deferred string for the newly created version ID
+ The newly created version ID
"""
def _create_e2e_room_keys_version_txn(txn):
@@ -363,23 +365,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@trace
- def update_e2e_room_keys_version(
- self, user_id, version, info=None, version_etag=None
- ):
+ async def update_e2e_room_keys_version(
+ self,
+ user_id: str,
+ version: str,
+ info: Optional[dict] = None,
+ version_etag: Optional[int] = None,
+ ) -> None:
"""Update a given backup version
Args:
- user_id(str): the user whose backup version we're updating
- version(str): the version ID of the backup version we're updating
- info (dict): the new backup version info to store. If None, then
- the backup version info is not updated
- version_etag (Optional[int]): etag of the keys in the backup. If
- None, then the etag is not updated
+ user_id: the user whose backup version we're updating
+ version: the version ID of the backup version we're updating
+ info: the new backup version info to store. If None, then the backup
+ version info is not updated.
+ version_etag: etag of the keys in the backup. If None, then the etag
+ is not updated.
"""
updatevalues = {}
@@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag
if updatevalues:
- return self.db_pool.simple_update(
+ await self.db_pool.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
@@ -397,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- def delete_e2e_room_keys_version(self, user_id, version=None):
+ async def delete_e2e_room_keys_version(
+ self, user_id: str, version: Optional[str] = None
+ ) -> None:
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
Args:
- user_id(str): the user whose backup version we're deleting
- version(str): Optional. the version ID of the backup version we're deleting
+ user_id: the user whose backup version we're deleting
+            version: Optional. The version ID of the backup version we're deleting
If missing, we delete the current backup version info.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present,
@@ -424,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version},
)
- return self.db_pool.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..fba3098ea2 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,8 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable, List, Optional, Tuple
+import abc
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@@ -23,24 +25,68 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.handlers.e2e_keys import SignatureListItem
+
+
+@attr.s
+class DeviceKeyLookupResult:
+ """The type returned by get_e2e_device_keys_and_signatures"""
+
+ display_name = attr.ib(type=Optional[str])
+
+ # the key data from e2e_device_keys_json. Typically includes fields like
+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+ # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
+ keys = attr.ib(type=Optional[JsonDict])
+
class EndToEndKeyWorkerStore(SQLBaseStore):
+ async def get_e2e_device_keys_for_federation_query(
+ self, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ now_stream_id = self.get_device_stream_token()
+
+ devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ keys = device.keys
+ if keys:
+ result["keys"] = keys
+
+ device_display_name = device.display_name
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
@trace
- async def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
+ async def get_e2e_device_keys_for_cs_api(
+ self, query_list: List[Tuple[str, Optional[str]]]
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -50,13 +96,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = await self.db_pool.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
- )
+ results = await self.get_e2e_device_keys_and_signatures(query_list)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
@@ -64,31 +104,95 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
- r = db_to_json(device_info.pop("key_json"))
+ r = device_info.keys
r["unsigned"] = {}
- display_name = device_info["device_display_name"]
+ display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
- if "signatures" in device_info:
- for sig_user_id, sigs in device_info["signatures"].items():
- r.setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
rv[user_id][device_id] = r
return rv
@trace
- def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: List[Tuple[str, Optional[str]]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Fetch a list of device keys
+
+ Any cross-signatures made on the keys by the owner of the device are also
+ included.
+
+ The cross-signatures are added to the `signatures` field within the `keys`
+ object in the response.
+
+ Args:
+ query_list: List of pairs of user_ids and device_ids. Device id can be None
+ to indicate "all devices for this user"
+
+ include_all_devices: whether to return devices without device keys
+
+ include_deleted_devices: whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
+
+ Returns:
+ Dict mapping from user-id to dict mapping from device_id to
+ key data.
+ """
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
+ result = await self.db_pool.runInteraction(
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
+ )
+
+ # get the (user_id, device_id) tuples to look up cross-signatures for
+ signature_query = (
+ (user_id, device_id)
+ for user_id, dev in result.items()
+ for device_id, d in dev.items()
+ if d is not None and d.keys is not None
+ )
+
+ for batch in batch_iter(signature_query, 50):
+ cross_sigs_result = await self.db_pool.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_for_devices_txn,
+ batch,
+ )
+
+ # add each cross-signing signature to the correct device in the result dict.
+ for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ target_device_result = result[user_id][device_id]
+ target_device_signatures = target_device_result.keys.setdefault(
+ "signatures", {}
+ )
+ signing_user_signatures = target_device_signatures.setdefault(
+ user_id, {}
+ )
+ signing_user_signatures[key_id] = signature
+
+ log_kv(result)
+ return result
+
+ def _get_e2e_device_keys_txn(
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Get information on devices from the database
+
+ The results include the device's keys and self-signatures, but *not* any
+ cross-signing signatures which have been added subsequently (for which, see
+ get_e2e_device_keys_and_signatures)
+ """
query_clauses = []
query_params = []
- signature_query_clauses = []
- signature_query_params = []
if include_all_devices is False:
include_deleted_devices = False
@@ -99,24 +203,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
- signature_query_clause = "target_user_id = ?"
- signature_query_params.append(user_id)
if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
- signature_query_clause += " AND target_device_id = ?"
- signature_query_params.append(device_id)
-
- signature_query_clause += " AND user_id = ?"
- signature_query_params.append(user_id)
query_clauses.append(query_clause)
- signature_query_clauses.append(signature_query_clause)
sql = (
"SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
+ " d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -127,51 +223,49 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
- result = {}
- for row in rows:
+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, db_to_json(key_json) if key_json else None
+ )
if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
- # get signatures on the device
- signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
- " OR ".join("(" + q + ")" for q in signature_query_clauses)
- )
+ return result
- txn.execute(signature_sql, signature_query_params)
- rows = self.db_pool.cursor_to_dict(txn)
-
- # add each cross-signing signature to the correct device in the result dict.
- for row in rows:
- signing_user_id = row["user_id"]
- signing_key_id = row["key_id"]
- target_user_id = row["target_user_id"]
- target_device_id = row["target_device_id"]
- signature = row["signature"]
-
- target_user_result = result.get(target_user_id)
- if not target_user_result:
- continue
+ def _get_e2e_cross_signing_signatures_for_devices_txn(
+ self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+ ) -> List[Tuple[str, str, str, str]]:
+ """Get cross-signing signatures for a given list of devices
- target_device_result = target_user_result.get(target_device_id)
- if not target_device_result:
- # note that target_device_result will be None for deleted devices.
- continue
+ Returns signatures made by the owners of the devices.
- target_device_signatures = target_device_result.setdefault("signatures", {})
- signing_user_signatures = target_device_signatures.setdefault(
- signing_user_id, {}
+ Returns: a list of results; each entry in the list is a tuple of
+ (user_id, key_id, target_device_id, signature).
+ """
+ signature_query_clauses = []
+ signature_query_params = []
+
+ for (user_id, device_id) in device_query:
+ signature_query_clauses.append(
+ "target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
- signing_user_signatures[signing_key_id] = signature
+ signature_query_params.extend([user_id, device_id, user_id])
- log_kv(result)
- return result
+ signature_sql = """
+ SELECT user_id, key_id, target_device_id, signature
+ FROM e2e_cross_signing_signatures WHERE %s
+ """ % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
+
+ txn.execute(signature_sql, signature_query_params)
+ return txn.fetchall()
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
@@ -249,10 +343,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
+ async def count_e2e_one_time_keys(
+ self, user_id: str, device_id: str
+ ) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
+ A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@@ -267,7 +363,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@@ -305,7 +401,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
- def _get_bare_e2e_cross_signing_keys_bulk(
+ async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@@ -313,16 +409,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
- user_ids (list[str]): the users whose keys are being requested
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A mapping from user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict, or
+ their user ID will map to None.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@@ -538,9 +633,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ async def set_e2e_device_keys(
+ self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ ) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@@ -576,12 +678,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
+ async def claim_e2e_one_time_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ """Take a list of one time keys out of the database.
+
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+            A map of user ID -> a map of device ID -> a map of key ID -> JSON bytes.
+ """
@trace
def _claim_e2e_one_time_keys(txn):
@@ -617,11 +728,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- def delete_e2e_keys_by_device(self, user_id, device_id):
+ async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@@ -644,11 +755,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -658,6 +769,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
+ stream_id (int): the stream ID to associate with the key
"""
# the 'key' dict will look something like:
# {
@@ -695,23 +807,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
- with self._cross_signing_id_gen.get_next() as stream_id:
- self.db_pool.simple_insert_txn(
- txn,
- "e2e_cross_signing_keys",
- values={
- "user_id": user_id,
- "keytype": key_type,
- "keydata": json_encoder.encode(key),
- "stream_id": stream_id,
- },
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json_encoder.encode(key),
+ "stream_id": stream_id,
+ },
+ )
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -719,22 +830,27 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db_pool.runInteraction(
- "add_e2e_cross_signing_key",
- self._set_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- key,
- )
- def store_e2e_cross_signing_signatures(self, user_id, signatures):
+ with await self._cross_signing_id_gen.get_next() as stream_id:
+ return await self.db_pool.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ stream_id,
+ )
+
+ async def store_e2e_cross_signing_signatures(
+ self, user_id: str, signatures: "Iterable[SignatureListItem]"
+ ) -> None:
"""Stores cross-signing signatures.
Args:
- user_id (str): the user who made the signatures
- signatures (iterable[SignatureListItem]): signatures to add
+ user_id: the user who made the signatures
+ signatures: signatures to add
"""
- return self.db_pool.simple_insert_many(
+ await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
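Note on the e2e key changes above: the new claim_e2e_one_time_keys docstring describes a three-level map (user ID -> device ID -> key ID -> JSON bytes). A minimal sketch of how such a map can be folded together from flat rows; the (user_id, device_id, algorithm, key_id, key_json) row shape is illustrative, not the exact query used in this patch:

    from typing import Dict, Iterable, Tuple

    def build_claimed_key_map(
        rows: Iterable[Tuple[str, str, str, str, bytes]]
    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
        # Fold flat rows into user ID -> device ID -> "algorithm:key_id" -> JSON bytes.
        result: Dict[str, Dict[str, Dict[str, bytes]]] = {}
        for user_id, device_id, algorithm, key_id, key_json in rows:
            device_map = result.setdefault(user_id, {}).setdefault(device_id, {})
            device_map["%s:%s" % (algorithm, key_id)] = key_json
        return result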
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..0b69aa6a94 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
+from synapse.events import EventBase
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -30,57 +32,51 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(
+ self, event_ids: Collection[str], include_given: bool = False
+ ) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
Returns:
list of events
"""
- return self.get_auth_chain_ids(
+ event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self.get_events_as_list)
-
- def get_auth_chain_ids(
- self,
- event_ids: List[str],
- include_given: bool = False,
- ignore_events: Optional[Set[str]] = None,
- ):
+ )
+ return await self.get_events_as_list(event_ids)
+
+ async def get_auth_chain_ids(
+ self, event_ids: Collection[str], include_given: bool = False,
+ ) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids: state events
include_given: include the given events in result
- ignore_events: Set of events to exclude from the returned auth
- chain. This is useful if the caller will just discard the
- given events anyway, and saves us from figuring out their auth
- chains if not required.
Returns:
- list of event_ids
+ A list of event_ids
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
- ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
- if ignore_events is None:
- ignore_events = set()
-
+ def _get_auth_chain_ids_txn(
+ self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE "
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -92,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
- new_front -= ignore_events
new_front -= results
front = new_front
@@ -100,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
- def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -109,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
chain.
Returns:
- Deferred[Set[str]]
+ The set of event IDs that form the auth chain difference.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
@@ -257,13 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- def get_oldest_events_in_room(self, room_id):
- return self.db_pool.runInteraction(
- "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
- )
-
- def get_oldest_events_with_depth_in_room(self, room_id):
- return self.db_pool.runInteraction(
+ async def get_oldest_events_with_depth_in_room(self, room_id):
+ return await self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -303,15 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
- def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.db_pool.simple_select_onecol_txn(
- txn,
- table="event_backward_extremities",
- keyvalues={"room_id": room_id},
- retcol="event_id",
- )
-
- def get_prev_events_for_room(self, room_id: str):
+ async def get_prev_events_for_room(self, room_id: str) -> List[str]:
"""
Gets a subset of the current forward extremities in the given room.
@@ -319,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
events which refer to hundreds of prev_events.
Args:
- room_id (str): room_id
+ room_id: room_id
Returns:
- Deferred[List[str]]: the event ids of the forward extremites
+ The event ids of the forward extremities.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
@@ -346,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row[0] for row in txn]
- def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
+ async def get_rooms_with_many_extremities(
+ self, min_count: int, limit: int, room_id_filter: Iterable[str]
+ ) -> List[str]:
"""Get the top rooms with at least N extremities.
Args:
- min_count (int): The minimum number of extremities
- limit (int): The maximum number of rooms to return.
- room_id_filter (iterable[str]): room_ids to exclude from the results
+ min_count: The minimum number of extremities
+ limit: The maximum number of rooms to return.
+ room_id_filter: room_ids to exclude from the results
Returns:
- Deferred[list]: At most `limit` room IDs that have at least
- `min_count` extremities, sorted by extremity count.
+ At most `limit` room IDs that have at least `min_count` extremities,
+ sorted by extremity count.
"""
def _get_rooms_with_many_extremities_txn(txn):
@@ -381,23 +365,23 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
- def get_latest_event_ids_in_room(self, room_id):
- return self.db_pool.simple_select_onecol(
+ async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
- def get_min_depth(self, room_id):
- """ For hte given room, get the minimum depth we have seen for it.
+ async def get_min_depth(self, room_id: str) -> int:
+ """For the given room, get the minimum depth we have seen for it.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
@@ -412,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
- def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def get_forward_extremeties_for_room(
+ self, room_id: str, stream_ordering: int
+ ) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -420,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
stream_orderings from that point.
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
- deferred, which resolves to a list of event_ids
+ A list of event_ids
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
@@ -440,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
- return self._get_forward_extremeties_for_room(room_id, stream_ordering)
+ return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
- def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -468,31 +454,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
- def get_backfill_events(self, room_id, event_list, limit):
+ async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
Args:
- txn
- room_id (str)
- event_list (list)
- limit (int)
+ room_id: The room to get events from.
+ event_list: The event IDs to backfill from.
+ limit: The maximum number of events to return.
"""
- return (
- self.db_pool.runInteraction(
- "get_backfill_events",
- self._get_backfill_events,
- room_id,
- event_list,
- limit,
- )
- .addCallback(self.get_events_as_list)
- .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+ event_ids = await self.db_pool.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ event_list,
+ limit,
)
+ events = await self.get_events_as_list(event_ids)
+ return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -553,8 +536,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = await self.get_events_as_list(ids)
- return events
+ return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -652,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
_delete_old_forward_extrem_cache_txn,
)
- def clean_room_for_join(self, room_id):
- return self.db_pool.runInteraction(
+ async def clean_room_for_join(self, room_id):
+ return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
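Note on the event_federation changes above: the reworked get_auth_chain_ids still performs a breadth-first walk over the event_auth table; only the ignore_events parameter is dropped and the query now uses SELECT DISTINCT. A self-contained sketch of the same walk over an in-memory auth graph (the auth_edges dict stands in for the event_auth table, which the real code queries in batches):

    from typing import Dict, Iterable, List, Set

    def auth_chain_ids(
        auth_edges: Dict[str, Set[str]],
        event_ids: Iterable[str],
        include_given: bool = False,
    ) -> List[str]:
        # Start from the given events and repeatedly follow auth edges until
        # nothing new is found, exactly like _get_auth_chain_ids_txn.
        results: Set[str] = set(event_ids) if include_given else set()
        front: Set[str] = set(event_ids)
        while front:
            new_front: Set[str] = set()
            for event_id in front:
                new_front |= auth_edges.get(event_id, set())
            new_front -= results
            results |= new_front
            front = new_front
        return list(results)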
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..5233ed83e2 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,13 +15,15 @@
# limitations under the License.
import logging
-from typing import List
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -86,83 +88,107 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
- @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
- def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
- ):
- ret = yield self.db_pool.runInteraction(
+ @cached(num_args=3, tree=True, max_entries=5000)
+ async def get_unread_event_push_actions_by_room_for_user(
+ self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ ) -> Dict[str, int]:
+ """Get the notification count, the highlight count and the unread message count
+ for a given user in a given room after the given read receipt.
+
+ Note that this function assumes the user to be a current member of the room,
+ since it's either called by the sync handler to handle joined room entries, or by
+ the HTTP pusher to calculate the badge of unread joined rooms.
+
+ Args:
+ room_id: The room to retrieve the counts in.
+ user_id: The user to retrieve the counts for.
+ last_read_event_id: The event associated with the latest read receipt for
+ this user in this room. None if no receipt for this user in this room.
+
+ Returns:
+ A dict containing the counts mentioned earlier in this docstring,
+ respectively under the keys "notify_count", "highlight_count" and
+ "unread_count".
+ """
+ return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- return ret
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id
+ self, txn, room_id, user_id, last_read_event_id,
):
- sql = (
- "SELECT stream_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(sql, (room_id, last_read_event_id))
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
+ stream_ordering = None
- stream_ordering = results[0][0]
+ if last_read_event_id is not None:
+ stream_ordering = self.get_stream_id_for_event_txn(
+ txn, last_read_event_id, allow_none=True,
+ )
+
+ if stream_ordering is None:
+ # Either last_read_event_id is None, or it's an event we don't have (e.g.
+ # because it's been purged), in which case retrieve the stream ordering for
+ # the latest membership event from this user in this room (which we assume is
+ # a join).
+ event_id = self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ retcol="event_id",
+ )
+
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
sql = (
- "SELECT count(*)"
+ "SELECT"
+ " COUNT(CASE WHEN notif = 1 THEN 1 END),"
+ " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+ " COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
+ " WHERE user_id = ?"
+ " AND room_id = ?"
+ " AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- notify_count = row[0] if row else 0
+
+ (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+ if row:
+ (notif_count, highlight_count, unread_count) = row
txn.execute(
"""
- SELECT notif_count FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
- """,
+ SELECT notif_count, unread_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """,
(room_id, user_id, stream_ordering),
)
- rows = txn.fetchall()
- if rows:
- notify_count += rows[0][0]
+ row = txn.fetchone()
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
+ if row:
+ notif_count += row[0]
- txn.execute(sql, (user_id, room_id, stream_ordering))
- row = txn.fetchone()
- highlight_count = row[0] if row else 0
+ if row[1] is not None:
+ # The unread_count column of event_push_summary is NULLable, so we need
+ # to make sure we don't try increasing the unread counts if it's NULL
+ # for this row.
+ unread_count += row[1]
- return {"notify_count": notify_count, "highlight_count": highlight_count}
+ return {
+ "notify_count": notif_count,
+ "unread_count": unread_count,
+ "highlight_count": highlight_count,
+ }
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
@@ -170,7 +196,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def f(txn):
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
- " stream_ordering >= ? AND stream_ordering <= ?"
+ " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
@@ -223,6 +249,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -251,6 +278,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -325,6 +353,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -353,6 +382,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -384,62 +414,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
return notifs[:limit]
- def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+ async def get_if_maybe_push_in_range_for_user(
+ self, user_id: str, min_stream_ordering: int
+ ) -> bool:
"""A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not.
Args:
- user_id (str)
- min_stream_ordering (int)
+ user_id
+ min_stream_ordering
Returns:
- Deferred[bool]: True if there may be push to process, False if
- there definitely isn't.
+ True if there may be push to process, False if there definitely isn't.
"""
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
- WHERE user_id = ? AND stream_ordering > ?
+ WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
- async def add_push_actions_to_staging(self, event_id, user_id_actions):
+ async def add_push_actions_to_staging(
+ self,
+ event_id: str,
+ user_id_actions: Dict[str, List[Union[dict, str]]],
+ count_as_unread: bool,
+ ) -> None:
"""Add the push actions for the event to the push action staging area.
Args:
- event_id (str)
- user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
- user_id to list of push actions, where an action can either be
- a string or dict.
-
- Returns:
- Deferred
+ event_id
+ user_id_actions: A mapping of user_id to list of push actions, where
+ an action can either be a string or dict.
+ count_as_unread: Whether this event should increment unread counts.
"""
-
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
- # can be used to inert into the `event_push_actions_staging` table.
+ # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
+ notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
- 1, # notif column
+ notif, # notif column
is_highlight, # highlight column
+ int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@@ -448,8 +482,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight)
- VALUES (?, ?, ?, ?, ?)
+ (event_id, user_id, actions, notif, highlight, unread)
+ VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
@@ -508,7 +542,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
- def find_first_stream_ordering_after_ts(self, ts):
+ async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
@@ -517,13 +551,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow.
Args:
- ts (int): timestamp in millis
+ ts: timestamp in millis
Returns:
- Deferred[int]: stream ordering of the first event received on/after
- the timestamp
+ stream ordering of the first event received on/after the timestamp
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -611,7 +644,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ?"
+ " WHERE ep.stream_ordering > ? AND notif = 1"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
@@ -675,6 +708,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" FROM event_push_actions epa, events e"
" WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
+ " AND epa.notif = 1"
" ORDER BY epa.stream_ordering DESC"
" LIMIT ?" % (before_clause,)
)
@@ -814,24 +848,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
- coalesce(old.notif_count, 0) + upd.notif_count,
+ coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
- SELECT user_id, room_id, count(*) as notif_count,
+ SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
+ AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
- rows = txn.fetchall()
+ # First get the count of unread messages.
+ txn.execute(
+ sql % ("unread_count", "unread"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # We need to merge results from the two requests (the one that retrieves the
+ # unread count and the one that retrieves the notifications count) into a single
+ # object because we might not have the same number of rows in each of them. To do
+ # this, we use a dict indexed on the user ID and room ID to make it easier to
+ # populate.
+ summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
+ for row in txn:
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=row[2],
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=0,
+ )
+
+ # Then get the count of notifications.
+ txn.execute(
+ sql % ("notif_count", "notif"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ for row in txn:
+ if (row[0], row[1]) in summaries:
+ summaries[(row[0], row[1])].notif_count = row[2]
+ else:
+ # Because the rules on notifying are different than the rules on marking
+ # a message unread, we might end up with messages that notify but aren't
+ # marked unread, so we might not have a summary for this (user, room)
+ # tuple to complete.
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=0,
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=row[2],
+ )
- logger.info("Rotating notifications, handling %d rows", len(rows))
+ logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
@@ -841,22 +914,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary",
values=[
{
- "user_id": row[0],
- "room_id": row[1],
- "notif_count": row[2],
- "stream_ordering": row[3],
+ "user_id": user_id,
+ "room_id": room_id,
+ "notif_count": summary.notif_count,
+ "unread_count": summary.unread_count,
+ "stream_ordering": summary.stream_ordering,
}
- for row in rows
- if row[4] is None
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is None
],
)
txn.executemany(
"""
- UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ UPDATE event_push_summary
+ SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
- ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+ (
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ user_id,
+ room_id,
+ )
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is not None
+ ),
)
txn.execute(
@@ -882,3 +967,15 @@ def _action_has_highlight(actions):
pass
return False
+
+
+@attr.s
+class _EventPushSummary:
+ """Summary of pending event push actions for a given user in a given room.
+ Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+ """
+
+ unread_count = attr.ib(type=int)
+ stream_ordering = attr.ib(type=int)
+ old_user_id = attr.ib(type=str)
+ notif_count = attr.ib(type=int)
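Note on the rotation changes above: because the unread counts and the notification counts now come from two separate queries, the results are merged into one dict keyed on (user_id, room_id). A standalone sketch of that merge, reusing the _EventPushSummary shape added by the patch (the row tuples follow the SELECT order user_id, room_id, count, stream_ordering, old_user_id, for illustration):

    from typing import Dict, Iterable, Tuple

    import attr

    @attr.s
    class _EventPushSummary:
        unread_count = attr.ib(type=int)
        stream_ordering = attr.ib(type=int)
        old_user_id = attr.ib(type=str)
        notif_count = attr.ib(type=int)

    def merge_summaries(
        unread_rows: Iterable[Tuple[str, str, int, int, str]],
        notif_rows: Iterable[Tuple[str, str, int, int, str]],
    ) -> Dict[Tuple[str, str], _EventPushSummary]:
        # First pass: one summary per (user, room) pair from the unread query.
        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
        for user_id, room_id, count, stream_ordering, old_user_id in unread_rows:
            summaries[(user_id, room_id)] = _EventPushSummary(
                unread_count=count,
                stream_ordering=stream_ordering,
                old_user_id=old_user_id,
                notif_count=0,
            )
        # Second pass: fill in (or create) the notification counts, since a row may
        # notify without being marked unread and vice versa.
        for user_id, room_id, count, stream_ordering, old_user_id in notif_rows:
            if (user_id, room_id) in summaries:
                summaries[(user_id, room_id)].notif_count = count
            else:
                summaries[(user_id, room_id)] = _EventPushSummary(
                    unread_count=0,
                    stream_ordering=stream_ordering,
                    old_user_id=old_user_id,
                    notif_count=count,
                )
        return summaries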
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..b3d27a2ee7 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @defer.inlineCallbacks
- def _persist_events_and_state_updates(
+ async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- ):
+ ) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -136,7 +133,7 @@ class PersistEventsStore:
backfilled
Returns:
- Deferred: resolves when the events have been persisted
+ Resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -156,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -168,7 +165,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
- @defer.inlineCallbacks
- def _get_events_which_are_prevs(self, event_ids):
+ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
- event_ids (Iterable[str]): event ids to filter
+ event_ids: event ids to filter
Returns:
- Deferred[List[str]]: filtered event ids
+ Filtered event ids
"""
results = []
@@ -240,14 +236,13 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
- @defer.inlineCallbacks
- def _get_prevs_before_rejected(self, event_ids):
+ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
- event_ids (Iterable[str]): Events to find prev events for. Note
- that these must have already been persisted.
+ event_ids: Events to find prev events for. Note that these must have
+ already been persisted.
Returns:
- Deferred[set[str]]
+ The previous events.
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
@@ -1301,9 +1296,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
+ topological_ordering, notif, highlight, unread
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
- @defer.inlineCallbacks
- def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
- num_handled = yield self.db_pool.runInteraction(
+ num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
- @defer.inlineCallbacks
- def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- count = yield self.db_pool.runInteraction(
+ count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
- yield self.db_pool.updates._end_background_update("redactions_received_ts")
+ await self.db_pool.updates._end_background_update("redactions_received_ts")
return count
- @defer.inlineCallbacks
- def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
- yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+ await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
- @defer.inlineCallbacks
- def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows
- num_rows = yield self.db_pool.runInteraction(
+ num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
- yield self.db_pool.updates._end_background_update("event_store_labels")
+ await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
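Note on the events_bg_updates changes above: after the async conversion, every handler in this file follows the same shape, namely run one batch inside a transaction, then end the background update once a batch processes zero rows. A hedged template of that shape; run_interaction, end_background_update and the update name are placeholders standing in for db_pool.runInteraction, db_pool.updates._end_background_update and the real update names:

    from typing import Awaitable, Callable

    async def run_background_update_batch(
        run_interaction: Callable[..., Awaitable[int]],
        end_background_update: Callable[[str], Awaitable[None]],
        update_name: str,
        batch_txn: Callable,
    ) -> int:
        # Run one batch of work; the transaction function reports how many rows it handled.
        num_rows = await run_interaction(update_name, batch_txn)
        if not num_rows:
            # Nothing left to do: mark the background update as complete.
            await end_background_update(update_name)
        return num_rows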
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..a7a73cc3d8 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
+from typing_extensions import Literal
from twisted.internet import defer
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -112,69 +113,58 @@ class EventsWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(token)
+ self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
- self._backfill_id_gen.advance(-token)
+ self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_received_ts(self, event_id):
+ async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
- event_id (str)
+ event_id: The event ID to query.
Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
+ Timestamp in milliseconds, or None for events that were persisted
+ before received_ts was implemented.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
desc="get_received_ts",
)
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
+ # Inform mypy that if allow_none is False (the default) then get_event
+ # always returns an EventBase.
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
+ ) -> EventBase:
+ ...
- return self.db_pool.runInteraction(
- "get_approximate_received_ts", _get_approximate_received_ts_txn
- )
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = False,
+ check_room_id: Optional[str] = None,
+ ) -> Optional[EventBase]:
+ ...
- @defer.inlineCallbacks
- def get_event(
+ async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -182,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
- ):
+ ) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@@ -207,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred[EventBase|None]
+ The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -230,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
- @defer.inlineCallbacks
- def get_events(
+ async def get_events(
self,
- event_ids: List[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@@ -256,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred : Dict from event_id to event.
+ A mapping from event_id to event.
"""
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -267,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
- @defer.inlineCallbacks
- def get_events_as_list(
+ async def get_events_as_list(
self,
- event_ids: List[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -295,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
+ List of events fetched from the database. The events are in the same
+ order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -306,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = yield self._get_events_from_cache_or_db(
+ event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -341,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -407,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
+ prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -419,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
- @defer.inlineCallbacks
- def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -435,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -453,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = yield self._get_events_from_db(
+ missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -561,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- @defer.inlineCallbacks
- def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -576,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -584,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
- row_map = yield self._enqueue_events(events_to_fetch)
+ row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -610,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = db_to_json(row["json"])
- internal_metadata = db_to_json(row["internal_metadata"])
+ # If the event or metadata cannot be parsed, log the error and act
+ # as if the event is unknown.
+ try:
+ d = db_to_json(row["json"])
+ except ValueError:
+ logger.error("Unable to parse json from event: %s", event_id)
+ continue
+ try:
+ internal_metadata = db_to_json(row["internal_metadata"])
+ except ValueError:
+ logger.error(
+ "Unable to parse internal_metadata from event: %s", event_id
+ )
+ continue
format_version = row["format_version"]
if format_version is None:
@@ -622,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
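Note on the hunk above: the comment explains why it is safe to guess the room version from the event format version for out-of-band membership events that arrived before #6983. Restated as a tiny standalone helper, using plain version strings instead of synapse's EventFormatVersions/RoomVersions constants purely for illustration:

    def guess_room_version_id(format_version: int) -> str:
        # Event format v1 was only used by room versions 1 and 2; format v2 only by
        # room version 3; anything newer here must be room v4 or v5. All of these
        # share the pre-v6 redaction algorithm, which is why the approximation is
        # adequate for redacting such events.
        if format_version == 1:
            return "1"
        if format_version == 2:
            return "3"
        return "5"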
@@ -686,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- @defer.inlineCallbacks
- def _enqueue_events(self, events):
+ async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -696,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
- Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -719,7 +735,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- row_map = yield events_d
+ row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -807,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict
- def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+ def _maybe_redact_event_row(
+ self,
+ original_ev: EventBase,
+ redactions: Iterable[str],
+ event_map: Dict[str, EventBase],
+ ) -> Optional[EventBase]:
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
- original_ev (EventBase):
- redactions (iterable[str]): list of event ids of potential redaction events
- event_map (dict[str, EventBase]): other events which have been fetched, in
- which we can look up the redaaction events. Map from event id to event.
+ original_ev: The original event.
+ redactions: list of event ids of potential redaction events
+ event_map: other events which have been fetched, in which we can
+ look up the redaction events. Map from event id to event.
Returns:
- Deferred[EventBase|None]: if the event should be redacted, a pruned
- event object. Otherwise, None.
+ If the event should be redacted, a pruned event object. Otherwise, None.
"""
if original_ev.type == "m.room.create":
# we choose to ignore redactions of m.room.create events.
@@ -878,12 +898,11 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -894,15 +913,14 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
+ async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
- Deferred[set[str]]: The events we have already seen.
+ set[str]: The events we have already seen.
"""
results = set()
@@ -918,41 +936,11 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.db_pool.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn,
- room_id,
- )
-
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -962,24 +950,23 @@ class EventsWorkerStore(SQLBaseStore):
row = txn.fetchone()
return row[0] if row else 0
- def get_current_state_event_counts(self, room_id):
+ async def get_current_state_event_counts(self, room_id: str) -> int:
"""
Gets the current number of state events in a room.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- Deferred[int]
+ The current number of state events.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
)
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -990,9 +977,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
+ dict[str:int] of complexity version to complexity.
"""
- state_events = yield self.get_current_state_event_counts(room_id)
+ state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1008,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ async def get_all_new_forward_event_rows(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
@@ -1016,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore):
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1037,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
- def get_ex_outlier_stream_rows(self, last_id, current_id):
+ async def get_ex_outlier_stream_rows(
+ self, last_id: int, current_id: int
+ ) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1071,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
@@ -1222,97 +1213,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
- else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
-
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
- )
-
- return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1320,9 +1220,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
- @cachedInlineCallbacks(max_entries=5000)
- def get_event_ordering(self, event_id):
- res = yield self.db_pool.simple_select_one(
+ @cached(max_entries=5000)
+ async def get_event_ordering(self, event_id):
+ res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1334,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore):
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_next_event_to_expire(self):
+ async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
- Returns: Deferred[Optional[Tuple[str, int]]]
+ Returns:
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
@@ -1354,17 +1254,6 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
-
-
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
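
The events_worker changes above set the pattern for the rest of this patch: @defer.inlineCallbacks generators become native coroutines, yield becomes await, and "Deferred[...]" return docs become plain return types. Below is a minimal sketch of that before/after shape, not Synapse code: run_interaction is a hypothetical stand-in for db_pool.runInteraction, and defer.ensureDeferred shows how Deferred-based callers keep working against the new coroutines.

from twisted.internet import defer, task


def run_interaction(desc, func, *args):
    # Hypothetical stand-in for db_pool.runInteraction: hand back an
    # already-fired Deferred carrying the "query" result.
    return defer.succeed(func(*args))


@defer.inlineCallbacks
def have_events_old(event_ids):
    # Old style: a generator driven by inlineCallbacks; yield unwraps the Deferred.
    rows = yield run_interaction("have_events", list, event_ids)
    return set(rows)


async def have_events_new(event_ids):
    # New style: a native coroutine; a Deferred can be awaited directly.
    rows = await run_interaction("have_events", list, event_ids)
    return set(rows)


def main(reactor):
    # Callers that still work in terms of Deferreds wrap the coroutine explicitly.
    old = have_events_old(["$a", "$b"])
    new = defer.ensureDeferred(have_events_new(["$a", "$b"]))
    done = defer.gatherResults([old, new])
    done.addCallback(print)
    return done


task.react(main)
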
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- def add_user_filter(self, user_localpart, user_filter):
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.db_pool.runInteraction("add_user_filter", _do_txn)
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..ccfbb2135e 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
- def get_group(self, group_id):
- return self.db_pool.simple_select_one(
+ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -44,31 +44,35 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group",
)
- def get_users_in_group(self, group_id, include_private=False):
+ async def get_users_in_group(
+ self, group_id: str, include_private: bool = False
+ ) -> List[Dict[str, Any]]:
# TODO: Pagination
keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group",
)
- def get_invited_users_in_group(self, group_id):
+ async def get_invited_users_in_group(self, group_id: str) -> List[str]:
# TODO: Pagination
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
desc="get_invited_users_in_group",
)
- def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+ async def get_rooms_in_group(
+ self, group_id: str, include_private: bool = False
+ ) -> List[Dict[str, Union[str, bool]]]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
@@ -77,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
- Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
- form of:
+ A list of dictionaries, each in the form of:
{
"room_id": "!a_room_id:example.com", # The ID of the room
@@ -115,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn
]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_in_group", _get_rooms_in_group_txn
)
- def get_rooms_for_summary_by_category(
+ async def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
- ):
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
Args:
@@ -129,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
- Deferred[Tuple[List, Dict]]: A tuple containing:
+ A tuple containing:
* A list of dictionaries with the keys:
* "room_id": str, the room ID
@@ -205,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
@@ -265,25 +268,25 @@ class GroupServerWorkerStore(SQLBaseStore):
return role
- def get_local_groups_for_room(self, room_id):
+ async def get_local_groups_for_room(self, room_id: str) -> List[str]:
"""Get all of the local group that contain a given room
Args:
- room_id (str): The ID of a room
+ room_id: The ID of a room
Returns:
- Deferred[list[str]]: A twisted.Deferred containing a list of group ids
- containing this room
+ A list of group ids containing this room
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
desc="get_local_groups_for_room",
)
- def get_users_for_summary_by_role(self, group_id, include_private=False):
+ async def get_users_for_summary_by_role(self, group_id, include_private=False):
"""Get the users and roles that should be included in a summary request
- Returns ([users], [roles])
+ Returns:
+ ([users], [roles])
"""
def _get_users_for_summary_txn(txn):
@@ -337,21 +340,24 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
- def is_user_in_group(self, user_id, group_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+ result = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
+ )
+ return bool(result)
- def is_user_admin_in_group(self, group_id, user_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_admin_in_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -359,10 +365,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
+ async def is_user_invited_to_local_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
"""Has the group server invited a user?
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -370,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_users_membership_info_in_group(self, group_id, user_id):
+ async def get_users_membership_info_in_group(self, group_id, user_id):
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -381,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": False,
}
- Returns an empty dict if the user is not join/invite/etc
+ Returns:
+ An empty dict if the user is not join/invite/etc
"""
def _get_users_membership_in_group_txn(txn):
@@ -413,21 +422,21 @@ class GroupServerWorkerStore(SQLBaseStore):
return {}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
- def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
desc="get_publicised_groups_for_user",
)
- def get_attestations_need_renewals(self, valid_until_ms):
+ async def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time
"""
@@ -439,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
txn.execute(sql, (valid_until_ms,))
return self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
@@ -461,15 +470,15 @@ class GroupServerWorkerStore(SQLBaseStore):
return None
- def get_joined_groups(self, user_id):
- return self.db_pool.simple_select_onecol(
+ async def get_joined_groups(self, user_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
desc="get_joined_groups",
)
- def get_all_groups_for_user(self, user_id, now_token):
+ async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
sql = """
SELECT group_id, type, membership, u.content
@@ -489,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in txn
]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
@@ -580,22 +589,41 @@ class GroupServerWorkerStore(SQLBaseStore):
class GroupServerStore(GroupServerWorkerStore):
- def set_group_join_policy(self, group_id, join_policy):
+ async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
"""Set the join policy of a group.
join_policy can be one of:
* "invite"
* "open"
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
desc="set_group_join_policy",
)
- def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
- return self.db_pool.runInteraction(
+ async def add_room_to_summary(
+ self,
+ group_id: str,
+ room_id: str,
+ category_id: str,
+ order: int,
+ is_public: Optional[bool],
+ ) -> None:
+ """Add (or update) room's entry in summary.
+
+ Args:
+ group_id
+ room_id
+ category_id: If not None then adds the category to the end of
+                the summary if it's not already there.
+ order: If not None inserts the room at that position, e.g. an order
+ of 1 will put the room first. Otherwise, the room gets added to
+ the end.
+ is_public
+ """
+ await self.db_pool.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -606,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_room_to_summary_txn(
- self, txn, group_id, room_id, category_id, order, is_public
- ):
+ self,
+ txn,
+ group_id: str,
+ room_id: str,
+ category_id: str,
+ order: int,
+ is_public: Optional[bool],
+ ) -> None:
"""Add (or update) room's entry in summary.
Args:
- group_id (str)
- room_id (str)
- category_id (str): If not None then adds the category to the end of
- the summary if its not already there. [Optional]
- order (int): If not None inserts the room at that position, e.g.
- an order of 1 will put the room first. Otherwise, the room gets
- added to the end.
+ txn
+ group_id
+ room_id
+ category_id: If not None then adds the category to the end of
+                the summary if it's not already there.
+ order: If not None inserts the room at that position, e.g. an order
+ of 1 will put the room first. Otherwise, the room gets added to
+ the end.
+ is_public
"""
room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -722,11 +758,13 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- def remove_room_from_summary(self, group_id, room_id, category_id):
+ async def remove_room_from_summary(
+ self, group_id: str, room_id: str, category_id: str
+ ) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self.db_pool.simple_delete(
+ return await self.db_pool.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -736,7 +774,13 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_room_from_summary",
)
- def upsert_group_category(self, group_id, category_id, profile, is_public):
+ async def upsert_group_category(
+ self,
+ group_id: str,
+ category_id: str,
+ profile: Optional[JsonDict],
+ is_public: Optional[bool],
+ ) -> None:
"""Add/update room category for group
"""
insertion_values = {}
@@ -752,7 +796,7 @@ class GroupServerStore(GroupServerWorkerStore):
else:
update_values["is_public"] = is_public
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -760,14 +804,20 @@ class GroupServerStore(GroupServerWorkerStore):
desc="upsert_group_category",
)
- def remove_group_category(self, group_id, category_id):
- return self.db_pool.simple_delete(
+ async def remove_group_category(self, group_id: str, category_id: str) -> int:
+ return await self.db_pool.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
)
- def upsert_group_role(self, group_id, role_id, profile, is_public):
+ async def upsert_group_role(
+ self,
+ group_id: str,
+ role_id: str,
+ profile: Optional[JsonDict],
+ is_public: Optional[bool],
+ ) -> None:
"""Add/remove user role
"""
insertion_values = {}
@@ -783,7 +833,7 @@ class GroupServerStore(GroupServerWorkerStore):
else:
update_values["is_public"] = is_public
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -791,15 +841,34 @@ class GroupServerStore(GroupServerWorkerStore):
desc="upsert_group_role",
)
- def remove_group_role(self, group_id, role_id):
- return self.db_pool.simple_delete(
+ async def remove_group_role(self, group_id: str, role_id: str) -> int:
+ return await self.db_pool.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
- def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
- return self.db_pool.runInteraction(
+ async def add_user_to_summary(
+ self,
+ group_id: str,
+ user_id: str,
+ role_id: str,
+ order: int,
+ is_public: Optional[bool],
+ ) -> None:
+ """Add (or update) user's entry in summary.
+
+ Args:
+ group_id
+ user_id
+ role_id: If not None then adds the role to the end of the summary if
+                it's not already there.
+ order: If not None inserts the user at that position, e.g. an order
+ of 1 will put the user first. Otherwise, the user gets added to
+ the end.
+ is_public
+ """
+ await self.db_pool.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -810,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_user_to_summary_txn(
- self, txn, group_id, user_id, role_id, order, is_public
+ self,
+ txn,
+ group_id: str,
+ user_id: str,
+ role_id: str,
+ order: int,
+ is_public: Optional[bool],
):
"""Add (or update) user's entry in summary.
Args:
- group_id (str)
- user_id (str)
- role_id (str): If not None then adds the role to the end of
- the summary if its not already there. [Optional]
- order (int): If not None inserts the user at that position, e.g.
- an order of 1 will put the user first. Otherwise, the user gets
- added to the end.
+ txn
+ group_id
+ user_id
+ role_id: If not None then adds the role to the end of the summary if
+                it's not already there.
+ order: If not None inserts the user at that position, e.g. an order
+ of 1 will put the user first. Otherwise, the user gets added to
+ the end.
+ is_public
"""
user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -922,46 +999,47 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- def remove_user_from_summary(self, group_id, user_id, role_id):
+ async def remove_user_from_summary(
+ self, group_id: str, user_id: str, role_id: str
+ ) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self.db_pool.simple_delete(
+ return await self.db_pool.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
)
- def add_group_invite(self, group_id, user_id):
+ async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
)
- def add_user_to_group(
+ async def add_user_to_group(
self,
- group_id,
- user_id,
- is_admin=False,
- is_public=True,
- local_attestation=None,
- remote_attestation=None,
- ):
+ group_id: str,
+ user_id: str,
+ is_admin: bool = False,
+ is_public: bool = True,
+        local_attestation: Optional[dict] = None,
+        remote_attestation: Optional[dict] = None,
+ ) -> None:
"""Add a user to the group server.
Args:
- group_id (str)
- user_id (str)
- is_admin (bool)
- is_public (bool)
- local_attestation (dict): The attestation the GS created to give
- to the remote server. Optional if the user and group are on the
- same server
- remote_attestation (dict): The attestation given to GS by remote
+ group_id
+ user_id
+ is_admin
+ is_public
+ local_attestation: The attestation the GS created to give to the remote
server. Optional if the user and group are on the same server
+ remote_attestation: The attestation given to GS by remote server.
+ Optional if the user and group are on the same server
"""
def _add_user_to_group_txn(txn):
@@ -1004,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
+ await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
- def remove_user_from_group(self, group_id, user_id):
+ async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -1034,26 +1112,30 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
- def add_room_to_group(self, group_id, room_id, is_public):
- return self.db_pool.simple_insert(
+ async def add_room_to_group(
+ self, group_id: str, room_id: str, is_public: bool
+ ) -> None:
+ await self.db_pool.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
- def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self.db_pool.simple_update(
+ async def update_room_in_group_visibility(
+ self, group_id: str, room_id: str, is_public: bool
+ ) -> int:
+ return await self.db_pool.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
desc="update_room_in_group_visibility",
)
- def remove_room_from_group(self, group_id, room_id):
+ async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -1067,14 +1149,16 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "room_id": room_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
- def update_group_publicity(self, group_id, user_id, publicise):
+ async def update_group_publicity(
+ self, group_id: str, user_id: str, publicise: bool
+ ) -> None:
"""Update whether the user is publicising their membership of the group
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@@ -1181,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with self._group_updates_id_gen.get_next() as next_id:
+ with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -1213,20 +1297,24 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_group_profile",
)
- def update_attestation_renewal(self, group_id, user_id, attestation):
+ async def update_attestation_renewal(
+ self, group_id: str, user_id: str, attestation: dict
+ ) -> None:
"""Update an attestation that we have renewed
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal",
)
- def update_remote_attestion(self, group_id, user_id, attestation):
+ async def update_remote_attestion(
+ self, group_id: str, user_id: str, attestation: dict
+ ) -> None:
"""Update an attestation that a remote has renewed
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
@@ -1236,16 +1324,16 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_remote_attestion",
)
- def remove_attestation_renewal(self, group_id, user_id):
+ async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int:
"""Remove an attestation that we thought we should renew, but actually
shouldn't. Ideally this would never get called as we would never
incorrectly try and do attestations for local users on local groups.
Args:
- group_id (str)
- user_id (str)
+ group_id
+ user_id
"""
- return self.db_pool.simple_delete(
+ return await self.db_pool.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
@@ -1254,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
- def delete_group(self, group_id):
+ async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
Args:
- group_id (str)
-
- Returns:
- Deferred
+ group_id: The group ID to delete.
"""
def _delete_group_txn(txn):
@@ -1285,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
txn, table=table, keyvalues={"group_id": group_id}
)
- return self.db_pool.runInteraction("delete_group", _delete_group_txn)
+ await self.db_pool.runInteraction("delete_group", _delete_group_txn)
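
The is_user_in_group hunk above shows a second recurring simplification: result shaping that used to hang off an addCallback becomes ordinary code after the await. A small sketch under the same caveat, where select_one_onecol is a toy stand-in for simple_select_one_onecol(..., allow_none=True) rather than the real helper:

from twisted.internet import defer


def select_one_onecol(rows, key):
    # Toy stand-in: a Deferred firing with the matching value, or None.
    return defer.succeed(rows.get(key))


def is_user_in_group_old(rows, user_id):
    # Old style: coerce the result by chaining a callback onto the Deferred.
    return select_one_onecol(rows, user_id).addCallback(lambda r: bool(r))


async def is_user_in_group_new(rows, user_id):
    # New style: await the lookup, then shape the result inline.
    result = await select_one_onecol(rows, user_id)
    return bool(result)
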
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..ad43bb05ab 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
import itertools
import logging
+from typing import Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -41,16 +42,17 @@ class KeyStore(SQLBaseStore):
@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
)
- def get_server_verify_keys(self, server_name_and_key_ids):
+ async def get_server_verify_keys(
+ self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+ ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
"""
Args:
- server_name_and_key_ids (iterable[Tuple[str, str]]):
+ server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
- map from (server_name, key_id) -> FetchKeyResult, or None if the key is
- unknown
+ A map from (server_name, key_id) -> FetchKeyResult, or None if the
+ key is unknown
"""
keys = {}
@@ -86,14 +88,19 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
- return self.db_pool.runInteraction("get_server_verify_keys", _txn)
+ return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ async def store_server_verify_keys(
+ self,
+ from_server: str,
+ ts_added_ms: int,
+ verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ ) -> None:
"""Stores NACL verification keys for remote servers.
Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ from_server: Where the verification keys were looked up
+ ts_added_ms: The time to record that the key was added
+ verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
@@ -115,13 +122,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i,))
- return res
-
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
@@ -134,24 +135,34 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- ).addCallback(_invalidate)
+ )
- def store_server_keys_json(
- self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
- ):
+ invalidate = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ invalidate((i,))
+
+ async def store_server_keys_json(
+ self,
+ server_name: str,
+ key_id: str,
+ from_server: str,
+ ts_now_ms: int,
+ ts_expires_ms: int,
+ key_json_bytes: bytes,
+ ) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
- server_name (str): The name of the server.
- key_id (str): The identifer of the key this JSON is for.
- from_server (str): The server this JSON was fetched from.
- ts_now_ms (int): The time now in milliseconds.
- ts_valid_until_ms (int): The time when this json stops being valid.
- key_json (bytes): The encoded JSON.
+ server_name: The name of the server.
+            key_id: The identifier of the key this JSON is for.
+ from_server: The server this JSON was fetched from.
+ ts_now_ms: The time now in milliseconds.
+            ts_expires_ms: The time when this JSON stops being valid.
+ key_json_bytes: The encoded JSON.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@@ -169,7 +180,9 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json",
)
- def get_server_keys_json(self, server_keys):
+ async def get_server_keys_json(
+ self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
+ ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
@@ -178,8 +191,7 @@ class KeyStore(SQLBaseStore):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
- Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
- Dict mapping (server_name, key_id, source) triplets to lists of dicts
+ A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
def _get_server_keys_json_txn(txn):
@@ -205,6 +217,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)
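
store_server_verify_keys applies the same idea to cache invalidation: rather than attaching an _invalidate callback to the upsert's Deferred, the invalidations run as straight-line code once the write has been awaited, which also makes the ordering explicit. A toy sketch, driven with asyncio purely for illustration (in Synapse the Twisted reactor drives these coroutines), with a plain dict standing in for the cached entries:

import asyncio

cache = {("server", "key1"): "stale"}   # toy stand-in for the cached entries
stored = []


async def upsert_many(rows):
    # Hypothetical stand-in for the simple_upsert_many_txn interaction.
    await asyncio.sleep(0)
    stored.extend(rows)


async def store_keys(rows):
    await upsert_many(rows)
    # Runs only after the write has completed; no callback chaining needed.
    for server_name, key_id, _ in rows:
        cache.pop((server_name, key_id), None)


asyncio.run(store_keys([("server", "key1", "new-key")]))
print(stored, cache)   # [('server', 'key1', 'new-key')] {}
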
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..86557d5512 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
- def get_local_media(self, media_id):
+ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
+
Returns:
None if the media_id doesn't exist.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -57,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media",
)
- def store_local_media(
+ async def store_local_media(
self,
media_id,
media_type,
@@ -66,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
user_id,
url_cache=None,
- ):
- return self.db_pool.simple_insert(
+ ) -> None:
+ await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -81,16 +84,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
- def mark_local_media_as_safe(self, media_id: str):
+ async def mark_local_media_as_safe(self, media_id: str) -> None:
"""Mark a local media as safe from quarantining."""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="local_media_repository",
keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True},
desc="mark_local_media_as_safe",
)
- def get_url_cache(self, url, ts):
+ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@@ -136,12 +139,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+ return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
- def store_url_cache(
+ async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -155,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- def get_local_media_thumbnails(self, media_id):
- return self.db_pool.simple_select_list(
+ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -169,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media_thumbnails",
)
- def store_local_thumbnail(
+ async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
@@ -178,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
- def get_cached_remote_media(self, origin, media_id):
- return self.db_pool.simple_select_one(
+ async def get_cached_remote_media(
+ self, origin, media_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -207,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_cached_remote_media",
)
- def store_cached_remote_media(
+ async def store_cached_remote_media(
self,
origin,
media_id,
@@ -217,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -232,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ async def update_cached_last_access_time(
+ self,
+ local_media: Iterable[str],
+ remote_media: Iterable[Tuple[str, str]],
+ time_ms: int,
+ ):
"""Updates the last access time of the given media
Args:
- local_media (iterable[str]): Set of media_ids
- remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ local_media: Set of media_ids
+ remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@@ -262,12 +272,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
- def get_remote_media_thumbnails(self, origin, media_id):
- return self.db_pool.simple_select_list(
+ async def get_remote_media_thumbnails(
+ self, origin: str, media_id: str
+ ) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -281,7 +293,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
- def store_remote_media_thumbnail(
+ async def store_remote_media_thumbnail(
self,
origin,
media_id,
@@ -292,7 +304,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -307,18 +319,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- def get_remote_media_before(self, before_ts):
+ async def get_remote_media_before(self, before_ts):
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
- return self.db_pool.execute(
+ return await self.db_pool.execute(
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
- def delete_remote_media(self, media_origin, media_id):
+ async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -331,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
- def get_expired_url_cache(self, now_ts):
+ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@@ -347,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@@ -364,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"delete_url_cache", _delete_url_cache_txn
)
- def get_url_cache_media_before(self, before_ts):
+ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -376,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
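
The write helpers in this file now await their insert and return None instead of handing the Deferred back. One behavioural point worth keeping in mind: a Deferred-returning method typically started its database work as soon as it was called, even if the caller ignored the result, whereas a coroutine that is never awaited never runs at all, so every call site has to await (or wrap with ensureDeferred). A small asyncio sketch of that difference, with a list standing in for the table:

import asyncio

writes = []


async def store_local_thumbnail(media_id: str) -> None:
    # Sketch of a write helper in the new style: await the insert, return None.
    await asyncio.sleep(0)
    writes.append(media_id)


async def main() -> None:
    await store_local_thumbnail("abc")   # the write happens
    store_local_thumbnail("lost")        # builds a coroutine object that never
                                         # executes; Python emits "coroutine ...
                                         # was never awaited"


asyncio.run(main())
print(writes)   # ['abc'] -- the un-awaited call recorded nothing
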
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..1d793d3deb 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import Dict, List, Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
@@ -33,11 +33,11 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self.hs = hs
@cached(num_args=0)
- def get_monthly_active_count(self):
+ async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
Returns:
- Defered[int]: Number of current monthly active users
+ Number of current monthly active users
"""
def _count_users(txn):
@@ -46,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- return self.db_pool.runInteraction("count_users", _count_users)
+ return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
- def get_monthly_active_count_by_service(self):
+ async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
@@ -57,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
method to return anything other than native matrix users.
Returns:
- Deferred[dict]: dict that includes a mapping between app_service_id
- and the number of occurrences.
+ A mapping between app_service_id and the number of occurrences.
"""
@@ -74,7 +73,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
result = txn.fetchall()
return dict(result)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_users_by_service", _count_users_by_service
)
@@ -99,17 +98,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
+    async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]:
"""
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
+        Checks if a given user is part of the monthly active user group
+        Args:
+            user_id: user to check
+
+        Returns:
+            Timestamp since last seen, None if never seen
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
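
The cached methods in this file are now plain async defs under @cached. The snippet below is only a toy memoiser showing why the two compose; the real descriptor also handles invalidation, per-instance caches and deduplication of concurrent lookups, none of which is modelled here.

import asyncio
import functools


def cached(func):
    # Toy stand-in for Synapse's @cached: await the coroutine once per key,
    # then serve the stored value.
    results = {}

    @functools.wraps(func)
    async def wrapper(self, key):
        if key not in results:
            results[key] = await func(self, key)
        return results[key]

    return wrapper


class Store:
    def __init__(self):
        self.queries = 0

    @cached
    async def get_count(self, key):
        self.queries += 1
        await asyncio.sleep(0)   # stands in for the count_users interaction
        return 42


async def main():
    store = Store()
    print(await store.get_count("mau"))   # 42, hits the "database"
    print(await store.get_count("mau"))   # 42, served from the cache
    print(store.queries)                  # 1


asyncio.run(main())
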
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index dcd1ff911a..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,9 +1,13 @@
+from typing import Optional
+
from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
- def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self.db_pool.simple_insert(
+ async def insert_open_id_token(
+ self, token: str, ts_valid_until_ms: int, user_id: str
+ ) -> None:
+ await self.db_pool.simple_insert(
table="open_id_tokens",
values={
"token": token,
@@ -13,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
- def get_user_id_for_open_id_token(self, token, ts_now_ms):
+ async def get_user_id_for_open_id_token(
+ self, token: str, ts_now_ms: int
+ ) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@@ -28,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,15 +15,15 @@
from typing import List, Tuple
+from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
- def get_presence_for_users(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_presence_for_users(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
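
update_presence now awaits get_next_mult before using the result, and the same shape appears earlier in the group store as: with await self._group_updates_id_gen.get_next() as next_id. The toy generator below shows how an awaitable get_next() that hands back a context manager supports that calling pattern; it is a sketch, not the real StreamIdGenerator, which among other things tracks ids that are still in flight.

import asyncio
import contextlib


class ToyIdGenerator:
    # Sketch of an id generator whose get_next() is awaited and then used as
    # a context manager: `with await gen.get_next() as stream_id: ...`
    def __init__(self, current: int = 0):
        self._current = current

    async def get_next(self):
        await asyncio.sleep(0)   # stands in for any async book-keeping
        self._current += 1
        next_id = self._current

        @contextlib.contextmanager
        def manager():
            yield next_id        # the id counts as "in flight" inside the block

        return manager()


async def main():
    gen = ToyIdGenerator()
    with await gen.get_next() as stream_id:
        print("persisting at stream id", stream_id)   # 1


asyncio.run(main())
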
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..d2e0685e9e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
- async def get_profileinfo(self, user_localpart):
+ async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
- def get_profile_displayname(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_displayname(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
- def get_profile_avatar_url(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_avatar_url(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
- def get_from_remote_profile_cache(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_from_remote_profile_cache(
+ self, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -63,21 +66,25 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
- def create_profile(self, user_localpart):
- return self.db_pool.simple_insert(
+ async def create_profile(self, user_localpart: str) -> None:
+ await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db_pool.simple_update_one(
+ async def set_profile_displayname(
+ self, user_localpart: str, new_displayname: str
+ ) -> None:
+ await self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
desc="set_profile_displayname",
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db_pool.simple_update_one(
+ async def set_profile_avatar_url(
+ self, user_localpart: str, new_avatar_url: str
+ ) -> None:
+ await self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@@ -86,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore):
- def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -103,8 +112,10 @@ class ProfileStore(ProfileWorkerStore):
desc="add_remote_profile_cache",
)
- def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db_pool.simple_update(
+ async def update_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> int:
+ return await self.db_pool.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
@@ -127,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
- def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ async def get_remote_profile_cache_entries_that_expire(
+ self, last_checked: int
+    ) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -142,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3526b6fd66..ea833829ae 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
- def purge_history(self, room_id, token, delete_local_events):
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> Set[int]:
"""Deletes room history before a certain point
Args:
- room_id (str):
-
- token (str): A topological token to delete events before
-
- delete_local_events (bool):
+ room_id:
+ token: A topological token to delete events before
+ delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
Returns:
- Deferred[set[int]]: The set of state groups that are referenced by
- deleted events.
+ The set of state groups that are referenced by deleted events.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
return referenced_state_groups
- def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> List[int]:
"""Deletes all record of a room
Args:
- room_id (str)
+ room_id
Returns:
- Deferred[List[int]]: The list of state groups to delete.
+ The list of state groups to delete.
"""
-
- return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+ return await self.db_pool.runInteraction(
+ "purge_room", self._purge_room_txn, room_id
+ )
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..0de802a86b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
-from twisted.internet import defer
-
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -30,9 +28,9 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -82,9 +80,9 @@ class PushRulesWorkerStore(
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ self._push_rules_stream_id_gen = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[StreamIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
@@ -115,9 +113,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_for_user(self, user_id):
+ rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -133,17 +131,15 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+ enabled_map = await self.get_push_rules_enabled_for_user(user_id)
use_new_defaults = user_id in self._users_new_default_push_rules
- rules = _load_rules(rows, enabled_map, use_new_defaults)
+ return _load_rules(rows, enabled_map, use_new_defaults)
- return rules
-
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_enabled_for_user(self, user_id):
+ results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -151,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- def have_push_rules_changed_for_user(self, user_id, last_id):
+ async def have_push_rules_changed_for_user(
+ self, user_id: str, last_id: int
+ ) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
+ return False
else:
def have_push_rules_changed_txn(txn):
@@ -165,23 +163,20 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
- def bulk_get_push_rules(self, user_ids):
+ async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+ enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
return results
- @defer.inlineCallbacks
- def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
+ new_room_id: ID of the new room.
+            user_id: ID of user the push rule belongs to.
+ rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
- yield self.add_push_rule(
+ await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
- @defer.inlineCallbacks
- def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
+ user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
- yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
- inlineCallbacks=True,
)
- def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
+ async def add_push_rule(
self,
user_id,
rule_id,
@@ -342,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
- ):
+ ) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
if before or after:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -362,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -546,16 +540,15 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
+ async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
+ user_id: The matrix ID of the push rule owner
+ rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -567,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -611,8 +605,9 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ async def set_push_rule_actions(
+ self, user_id, rule_id, actions, is_default_rule
+ ) -> None:
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -651,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@@ -681,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
+ return self._push_rules_stream_id_gen.get_current_token()
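
The substantive change in this file is not the async conversion but the switch of id generators: ChainedIdGenerator's context manager yielded a `(stream_id, event_stream_ordering)` pair, while StreamIdGenerator's `get_next()` is a coroutine that resolves to a context manager yielding a single stream id, with the event stream ordering now read separately via `self._stream_id_gen.get_current_token()`. A toy generator with the same calling shape, purely to illustrate the `with await ... as stream_id:` pattern (the real class also tracks in-flight ids across writers, which is omitted here):

    import asyncio
    from contextlib import contextmanager

    class ToyStreamIdGenerator:
        def __init__(self) -> None:
            self._current = 0

        async def get_next(self):
            # A coroutine resolving to an ordinary context manager, so callers
            # write `with await gen.get_next() as stream_id:`.
            self._current += 1
            allocated = self._current

            @contextmanager
            def _ctx():
                yield allocated

            return _ctx()

        def get_current_token(self) -> int:
            return self._current

    async def main():
        gen = ToyStreamIdGenerator()
        with await gen.get_next() as stream_id:
            print("writing push rule change at stream id", stream_id)  # 1

    asyncio.run(main())
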
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
- dataJson = r["data"]
+ data_json = r["data"]
try:
- r["data"] = db_to_json(dataJson)
+ r["data"] = db_to_json(data_json)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
- dataJson,
+ data_json,
e.args[0],
)
continue
yield r
- @defer.inlineCallbacks
- def user_has_pusher(self, user_id):
- ret = yield self.db_pool.simple_select_one_onecol(
+ async def user_has_pusher(self, user_id):
+ ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
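
Aside from the snake_case rename, `_decode_pushers_rows` keeps its defensive shape: rows whose `data` column is not valid JSON are logged and skipped rather than failing the whole query. A standalone sketch of that pattern, using `json.loads` in place of Synapse's `db_to_json`:

    import json
    import logging

    logger = logging.getLogger(__name__)

    def decode_pusher_rows(rows):
        """Yield rows with `data` parsed from JSON; log and drop undecodable rows."""
        for r in rows:
            data_json = r["data"]
            try:
                r["data"] = json.loads(data_json)
            except Exception as e:
                logger.warning(
                    "Invalid JSON in data for pusher %d: %s, %s", r["id"], data_json, e
                )
                continue
            yield r

    rows = [
        {"id": 1, "data": '{"url": "https://push.example.com/_matrix/push/v1/notify"}'},
        {"id": 2, "data": "{not json"},
    ]
    print([r["id"] for r in decode_pusher_rows(rows)])  # [1]; pusher 2 is dropped
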
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
- @defer.inlineCallbacks
- def get_pushers_by(self, keyvalues):
- ret = yield self.db_pool.simple_select_list(
+ async def get_pushers_by(self, keyvalues):
+ ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- @defer.inlineCallbacks
- def get_all_pushers(self):
+ async def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
- return rows
+ return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
- @cachedInlineCallbacks(num_args=1, max_entries=15000)
- def get_if_user_has_pusher(self, user_id):
+ @cached(num_args=1, max_entries=15000)
+ async def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- def get_if_users_have_pushers(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_if_users_have_pushers(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
return result
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
+ async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self.db_pool.simple_update_one(
+ ) -> None:
+ await self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
+ async def update_pusher_last_stream_ordering_and_success(
+ self,
+ app_id: str,
+ pushkey: str,
+ user_id: str,
+ last_stream_ordering: int,
+ last_success: int,
+ ) -> bool:
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
- app_id (str)
- pushkey (str)
- last_stream_ordering (int)
- last_success (int)
+ app_id
+ pushkey
+ user_id
+ last_stream_ordering
+ last_success
Returns:
- Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.db_pool.simple_update(
+ updated = await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db_pool.simple_update(
+ async def update_pusher_failing_since(
+ self, app_id, pushkey, user_id, failing_since
+ ) -> None:
+ await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db_pool.simple_select_list(
+ async def get_throttle_params_by_room(self, pusher_id):
+ res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
+ async def set_throttle_params(self, pusher_id, room_id, params) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_pusher(
+ async def add_pusher(
self,
user_id,
access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
- ):
- with self._pushers_id_gen.get_next() as stream_id:
+ ) -> None:
+ with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
- @defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+ async def delete_pusher_by_app_id_pushkey_user_id(
+ self, app_id, pushkey, user_id
+ ) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
@@ -350,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with self._pushers_id_gen.get_next() as stream_id:
- yield self.db_pool.runInteraction(
+ with await self._pushers_id_gen.get_next() as stream_id:
+ await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
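
`@cachedInlineCallbacks` disappears throughout this patch because `@cached` can wrap a native coroutine directly, and `@cachedList` likewise drops its `inlineCallbacks=True` flag. A heavily simplified, hypothetical memoising decorator, only to show why a separate inlineCallbacks variant is unnecessary once the method itself is `async def` (the real descriptor also handles invalidation, size limits and batched `cachedList` lookups):

    import asyncio
    import functools

    def cached(max_entries: int = 5000):
        """Toy memoiser for async methods; the cache is shared across instances."""

        def decorator(func):
            cache = {}

            @functools.wraps(func)
            async def wrapper(self, *args):
                if args not in cache:
                    cache[args] = await func(self, *args)
                return cache[args]

            return wrapper

        return decorator

    class ToyPusherStore:
        queries = 0

        @cached(max_entries=15000)
        async def user_has_pusher(self, user_id: str) -> bool:
            ToyPusherStore.queries += 1
            return user_id == "@alice:example.org"

    async def main():
        store = ToyPusherStore()
        await store.user_has_pusher("@alice:example.org")
        await store.user_has_pusher("@alice:example.org")
        print(ToyPusherStore.queries)  # 1: the second call was served from the cache

    asyncio.run(main())
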
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..4a0d5a320e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,14 +56,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
- @cachedInlineCallbacks()
- def get_users_with_read_receipts_in_room(self, room_id):
- receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ @cached()
+ async def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
- def get_receipts_for_room(self, room_id, receipt_type):
- return self.db_pool.simple_select_list(
+ async def get_receipts_for_room(
+ self, room_id: str, receipt_type: str
+ ) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -71,8 +73,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
- def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.db_pool.simple_select_one_onecol(
+ async def get_last_receipt_event_id_for_user(
+ self, user_id: str, room_id: str, receipt_type: str
+ ) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -84,9 +88,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedInlineCallbacks(num_args=2)
- def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db_pool.simple_select_list(
+ @cached(num_args=2)
+ async def get_receipts_for_user(self, user_id, receipt_type):
+ rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,8 +99,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- @defer.inlineCallbacks
- def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -110,7 +113,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db_pool.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -122,56 +125,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def get_linearized_receipts_for_rooms(
+ self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_ids (list): List of room_ids.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: List of room_ids.
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- list: A list of receipts.
+ A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
- room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
- results = yield self._get_linearized_receipts_for_rooms(
+ results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
- def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ async def get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
- room_ids (str): The room id.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: The room id.
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- Deferred[list]: A list of receipts.
+ A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
- defer.succeed([])
+ return []
- return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+ return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cachedInlineCallbacks(num_args=3, tree=True)
- def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ @cached(num_args=3, tree=True)
+ async def _get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -195,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+ rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -212,9 +220,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
- inlineCallbacks=True,
)
- def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +250,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db_pool.runInteraction(
+ txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -269,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
- def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+ async def get_users_sent_receipts_between(
+ self, last_id: int, current_id: int
+ ) -> List[str]:
"""Get all users who sent receipts between `last_id` exclusive and
`current_id` inclusive.
Returns:
- Deferred[List[str]]
+ The list of users.
"""
if last_id == current_id:
@@ -289,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
)
@@ -346,7 +355,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
- self, room_id, receipt_type, user_id
+ self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -472,15 +481,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
- @defer.inlineCallbacks
- def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ async def insert_receipt(
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: dict,
+ ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
- return
+ return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -507,13 +522,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db_pool.runInteraction(
+ linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = self._receipts_id_gen.get_next()
- with stream_id_manager as stream_id:
- event_ts = yield self.db_pool.runInteraction(
+ with await self._receipts_id_gen.get_next() as stream_id:
+ event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -535,14 +549,16 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
- yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+ await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
- def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.db_pool.runInteraction(
+ async def insert_graph_receipt(
+ self, room_id, receipt_type, user_id, event_ids, data
+ ):
+ return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
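
One of the few behavioural fixes hiding in this file: the fast path in `get_linearized_receipts_for_room` used to build `defer.succeed([])` without returning it, so the stream-change-cache check never actually short-circuited; the async version returns `[]` directly. A self-contained sketch of the corrected control flow (the cache class here is a hypothetical stand-in keyed only by the latest change position):

    import asyncio
    from typing import List, Optional

    class ToyStreamChangeCache:
        def __init__(self):
            self._last_change = {}

        def entity_has_changed(self, entity: str, pos: int) -> None:
            self._last_change[entity] = pos

        def has_entity_changed(self, entity: str, from_key: int) -> bool:
            return self._last_change.get(entity, 0) > from_key

    async def get_linearized_receipts_for_room(
        cache: ToyStreamChangeCache,
        room_id: str,
        to_key: int,
        from_key: Optional[int] = None,
    ) -> List[dict]:
        if from_key is not None:
            # The old code constructed defer.succeed([]) here and then fell
            # through to the database query anyway; now we genuinely return early.
            if not cache.has_entity_changed(room_id, from_key):
                return []
        return [{"room_id": room_id, "type": "m.receipt"}]  # pretend database hit

    cache = ToyStreamChangeCache()
    cache.entity_has_changed("!room:example.org", 10)
    print(asyncio.run(get_linearized_receipts_for_room(cache, "!room:example.org", 20, 15)))  # []
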
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..01f20c03c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
import logging
import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -48,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
- def get_user_by_id(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -86,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
- def get_user_by_access_token(self, token):
+ async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
- token (str): The access token of a user.
+ token: The access token of a user.
Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@cached()
- async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
+ async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
@@ -283,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
- def set_server_admin(self, user, admin):
+ async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
- admin (bool): true iff the user is to be a server admin,
- false otherwise.
+            user: user ID of the user to modify
+ admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
@@ -300,11 +297,11 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),)
)
- return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -366,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
- def get_users_by_id_case_insensitive(self, user_id):
+ async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
- Returns a mapping of user_id -> password_hash.
+
+ Returns:
+ A mapping of user_id -> password_hash.
"""
def f(txn):
@@ -376,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
- return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+ return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@@ -410,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users)
- def count_daily_user_type(self):
+ async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guests users
@@ -439,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@@ -531,43 +530,42 @@ class RegistrationWorkerStore(SQLBaseStore):
"user_get_threepids",
)
- def user_delete_threepid(self, user_id, medium, address):
- return self.db_pool.simple_delete(
+ async def user_delete_threepid(self, user_id, medium, address) -> None:
+ await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid",
)
- def user_delete_threepids(self, user_id: str):
+ async def user_delete_threepids(self, user_id: str) -> None:
"""Delete all threepid this user has bound
Args:
user_id: The user id to delete all threepids of
"""
- return self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id},
desc="user_delete_threepids",
)
- def add_user_bound_threepid(self, user_id, medium, address, id_server):
+ async def add_user_bound_threepid(
+ self, user_id: str, medium: str, address: str, id_server: str
+    ) -> None:
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str)
-
- Returns:
- Deferred
+ user_id
+ medium
+ address
+ id_server
"""
# We need to use an upsert, in case they user had already bound the
# threepid
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -580,41 +578,40 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
- def user_get_bound_threepids(self, user_id):
+ async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
Args:
- user_id (str): The ID of the user to retrieve threepids for
+ user_id: The ID of the user to retrieve threepids for
Returns:
- Deferred[list[dict]]: List of dictionaries containing the following:
+ List of dictionaries containing the following keys:
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
desc="user_get_bound_threepids",
)
- def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+ async def remove_user_bound_threepid(
+ self, user_id: str, medium: str, address: str, id_server: str
+ ) -> None:
"""The server proxied an unbind request to the given identity server on
behalf of the given user, so we remove the mapping of threepid to
identity server.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str)
-
- Returns:
- Deferred
+ user_id
+ medium
+ address
+ id_server
"""
- return self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -625,19 +622,21 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="remove_user_bound_threepid",
)
- def get_id_servers_user_bound(self, user_id, medium, address):
+ async def get_id_servers_user_bound(
+ self, user_id: str, medium: str, address: str
+ ) -> List[str]:
"""Get the list of identity servers that the server proxied bind
requests to for given user and threepid
Args:
- user_id (str)
- medium (str)
- address (str)
+ user_id: The user to query for identity servers.
+ medium: The medium to query for identity servers.
+ address: The address to query for identity servers.
Returns:
- Deferred[list[str]]: Resolves to a list of identity servers
+ A list of identity servers
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@@ -665,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
- def get_threepid_validation_session(
- self, medium, client_secret, address=None, sid=None, validated=True
- ):
+ async def get_threepid_validation_session(
+ self,
+ medium: Optional[str],
+ client_secret: str,
+ address: Optional[str] = None,
+ sid: Optional[str] = None,
+ validated: Optional[bool] = True,
+ ) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
- medium (str|None): The medium of the 3PID
- address (str|None): The address of the 3PID
- sid (str|None): The ID of the validation session
- client_secret (str): A unique string provided by the client to help identify this
+ medium: The medium of the 3PID
+ client_secret: A unique string provided by the client to help identify this
validation attempt
- validated (bool|None): Whether sessions should be filtered by
+ address: The address of the 3PID
+ sid: The ID of the validation session
+ validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
- Deferred[dict|None]: A dict containing the following:
+ A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@@ -728,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
- def delete_threepid_session(self, session_id):
+ async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
- session_id (str): The ID of the session to delete
+ session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
@@ -753,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@@ -891,6 +895,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -942,40 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
- def register_user(
+ async def register_user(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Attempts to register an account.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
- upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
- false to add a regular user account.
- appservice_id (str): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode): Optionally create a profile for
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Whether this is a guest account being upgraded to a
+ non-guest account.
+            make_guest: True if the new user should be guest, false to add a
+ regular user account.
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
- api.constants.UserTypes, or None for a normal user.
+ admin: is an admin user?
+ user_type: type of user. One of the values from api.constants.UserTypes,
+ or None for a normal user.
+ shadow_banned: Whether the user is shadow-banned, i.e. they may be
+ told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
-
- Returns:
- Deferred
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@@ -986,6 +991,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
)
def _register_user(
@@ -999,6 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1028,6 +1035,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
else:
@@ -1042,6 +1050,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
@@ -1075,9 +1084,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- def record_user_external_id(
+ async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
- ) -> Deferred:
+ ) -> None:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1085,7 +1094,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@@ -1095,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- def user_set_password_hash(self, user_id, password_hash):
+ async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1108,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
- def user_set_consent_version(self, user_id, consent_version):
+ async def user_set_consent_version(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy the user has consented
- to
+ user_id: full mxid of the user to update
+ consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
@@ -1133,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction("user_set_consent_version", f)
+ await self.db_pool.runInteraction("user_set_consent_version", f)
- def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ async def user_set_consent_server_notice_sent(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy we have notified the
- user about
+ user_id: full mxid of the user to update
+ consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
@@ -1157,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
+ await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+ async def user_delete_access_tokens(
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
+ ) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str): list of access_tokens IDs which should
- *not* be deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_tokens ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
- defer.Deferred[list[str, int, str|None, int]]: a list of
- (token, token id, device id) for each of the deleted tokens
+            A list of (token, token id, device id) tuples, one for each deleted token
"""
def f(txn):
@@ -1203,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
- return self.db_pool.runInteraction("user_delete_access_tokens", f)
+ return await self.db_pool.runInteraction("user_delete_access_tokens", f)
- def delete_access_token(self, access_token):
+ async def delete_access_token(self, access_token: str) -> None:
def f(txn):
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
@@ -1215,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
- return self.db_pool.runInteraction("delete_access_token", f)
+ await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
@@ -1229,36 +1243,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return res if res else False
- def add_user_pending_deactivation(self, user_id):
+ async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
)
- def del_user_pending_deactivation(self, user_id):
+ async def del_user_pending_deactivation(self, user_id: str) -> None:
"""
Removes the given user to the table of users who need to be parted from all the
rooms they're in, effectively marking that user as fully deactivated.
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
)
- def get_user_pending_deactivation(self):
+ async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1266,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
- def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+ async def validate_threepid_session(
+ self, session_id: str, client_secret: str, token: str, current_ts: int
+ ) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
- session_id (str): The id of a validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- token (str): A validation token
- current_ts (int): The current unix time in milliseconds. Used for
- checking token expiry status
+ session_id: The id of a validation session
+ client_secret: A unique string provided by the client to help identify
+ this validation attempt
+ token: A validation token
+ current_ts: The current unix time in milliseconds. Used for checking
+ token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
- deferred str|None: A str representing a link to redirect the user
- to if there is one.
+ A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
@@ -1297,15 +1312,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1321,6 +1343,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
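
The hunk above changes behaviour as well as style: when `request_token_inhibit_3pid_errors` is enabled, an unknown session id no longer raises immediately; a placeholder row is substituted so the request fails at the token lookup instead, and the client-secret comparison is deferred until after that lookup. A simplified, standalone sketch of that ordering (ValueError stands in for ThreepidValidationError, and the two database lookups are passed in as plain dicts):

    from typing import Optional

    def check_session(
        session_row: Optional[dict],
        token_row: Optional[dict],
        client_secret: str,
        inhibit_unknown_session_error: bool,
    ) -> Optional[str]:
        if not session_row:
            if inhibit_unknown_session_error:
                # Placeholder values: the token lookup below will fail anyway, and
                # client_secret is NOT NULL in the schema, so None can never match.
                session_row = {"client_secret": None, "validated_at": None}
            else:
                raise ValueError("Unknown session_id")

        if not token_row:
            raise ValueError("Validation token not found or has expired")

        # Only compared once we know a token exists, mirroring the reordered check.
        if session_row["client_secret"] != client_secret:
            raise ValueError("This client_secret does not match the provided session_id")

        return token_row.get("next_link")

    # With inhibition on, an unknown session surfaces as a token error, not a session error:
    try:
        check_session(None, None, "secret", inhibit_unknown_session_error=True)
    except ValueError as e:
        print(e)  # Validation token not found or has expired
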
@@ -1341,73 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
- def upsert_threepid_validation_session(
- self,
- medium,
- address,
- client_secret,
- send_attempt,
- session_id,
- validated_at=None,
- ):
- """Upsert a threepid validation session
- Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- session_id (str): The id of this validation session
- validated_at (int|None): The unix timestamp in milliseconds of
- when the session was marked as valid
- """
- insertion_values = {
- "medium": medium,
- "address": address,
- "client_secret": client_secret,
- }
-
- if validated_at:
- insertion_values["validated_at"] = validated_at
-
- return self.db_pool.simple_upsert(
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- values={"last_send_attempt": send_attempt},
- insertion_values=insertion_values,
- desc="upsert_threepid_validation_session",
- )
-
- def start_or_continue_validation_session(
+ async def start_or_continue_validation_session(
self,
- medium,
- address,
- session_id,
- client_secret,
- send_attempt,
- next_link,
- token,
- token_expires,
- ):
+ medium: str,
+ address: str,
+ session_id: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str],
+ token: str,
+ token_expires: int,
+ ) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- session_id (str): The id of this validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- next_link (str|None): The link to redirect the user to upon
- successful validation
- token (str): The validation token
- token_expires (int): The timestamp for which after the token
- will no longer be valid
+ medium: The medium of the 3PID
+ address: The address of the 3PID
+ session_id: The id of this validation session
+ client_secret: A unique string provided by the client to help
+ identify this validation attempt
+ send_attempt: The latest send_attempt on this session
+ next_link: The link to redirect the user to upon successful validation
+ token: The validation token
+            token_expires: The timestamp after which the token will no
+                longer be valid
"""
def start_or_continue_validation_session_txn(txn):
@@ -1436,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
- def cull_expired_threepid_validation_tokens(self):
+ async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1449,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
- return txn.execute(sql, (ts,))
+ txn.execute(sql, (ts,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
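
`register_user` also gains a `shadow_banned` keyword defaulting to False, threaded through `_register_user` into the `users` table, so existing call sites behave exactly as before. A tiny sketch of how a hypothetical caller opts in (the handler side is outside this diff; the signature below is illustrative, not the real one):

    import asyncio

    async def register_user(
        user_id: str, password_hash=None, admin: bool = False, shadow_banned: bool = False
    ) -> None:
        # Toy signature mirroring the new default: omitting shadow_banned keeps
        # the old behaviour.
        print(f"registering {user_id} admin={admin} shadow_banned={shadow_banned}")

    asyncio.run(register_user("@alice:example.org"))
    asyncio.run(register_user("@spammer:example.org", shadow_banned=True))
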
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def get_rejection_reason(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
- def get_relations_for_event(
+ async def get_relations_for_event(
self,
- event_id,
- relation_type=None,
- event_type=None,
- aggregation_key=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[RelationPaginationToken] = None,
+ to_token: Optional[RelationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
- event_id (str): Fetch events that relate to this event ID.
- relation_type (str|None): Only fetch events with this relation
- type, if given.
- event_type (str|None): Only fetch events with this event type, if
- given.
- aggregation_key (str|None): Only fetch events with this aggregation
- key, if given.
- limit (int): Only fetch the most recent `limit` events.
- direction (str): Whether to fetch the most recent first (`"b"`) or
- the oldest first (`"f"`).
- from_token (RelationPaginationToken|None): Fetch rows from the given
- token, or from the start if None.
- to_token (RelationPaginationToken|None): Fetch rows up to the given
- token, or up to the end if None.
+ event_id: Fetch events that relate to this event ID.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of event IDs that match relations
- requested. The rows are of the form `{"event_id": "..."}`.
+ List of event IDs that match relations requested. The rows are of
+ the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@cached(tree=True)
- def get_aggregation_groups_for_event(
+ async def get_aggregation_groups_for_event(
self,
- event_id,
- event_type=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ event_type: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[AggregationPaginationToken] = None,
+ to_token: Optional[AggregationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
on an event.
Args:
- event_id (str): Fetch events that relate to this event ID.
- event_type (str|None): Only fetch events with this event type, if
- given.
- limit (int): Only fetch the `limit` groups.
- direction (str): Whether to fetch the highest count first (`"b"`) or
+ event_id: Fetch events that relate to this event ID.
+ event_type: Only fetch events with this event type, if given.
+ limit: Only fetch the `limit` groups.
+ direction: Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
- from_token (AggregationPaginationToken|None): Fetch rows from the
- given token, or from the start if None.
- to_token (AggregationPaginationToken|None): Fetch rows up to the
- given token, or up to the end if None.
-
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of groups of annotations that
- match. Each row is a dict with `type`, `key` and `count` fields.
+ List of groups of annotations that match. Each row is a dict with
+ `type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.get_event(edit_id, allow_none=True)
- def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ async def has_user_annotated_event(
+ self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+ ) -> bool:
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
- parent_id (str): The event being annotated
- event_type (str): The event type of the annotation
- aggregation_key (str): The aggregation key of the annotation
- sender (str): The sender of the annotation
+ parent_id: The event being annotated
+ event_type: The event type of the annotation
+ aggregation_key: The aggregation key of the annotation
+ sender: The sender of the annotation
Returns:
- Deferred[bool]
+ True if the event is already annotated.
"""
sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
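Illustration only, not part of the patch: a hypothetical call site for the now-awaitable get_relations_for_event. The store variable, the "m.annotation" relation type string, and the assumption that next_batch is None on the last page are not taken from this diff; the chunk/next_batch attributes and the {"event_id": ...} row shape come from the docstring and PaginationChunk construction above.

async def iter_annotation_event_ids(store, event_id: str):
    # Page backwards through annotations on an event, five at a time.
    from_token = None
    while True:
        chunk = await store.get_relations_for_event(
            event_id,
            relation_type="m.annotation",
            limit=5,
            direction="b",
            from_token=from_token,
        )
        for row in chunk.chunk:
            yield row["event_id"]
        if chunk.next_batch is None:
            break
        from_token = chunk.next_batch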
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..717df97301 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,24 +21,19 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from canonicaljson import json
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@@ -78,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
- def get_room(self, room_id):
+ async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve a room.
Args:
- room_id (str): The ID of the room to retrieve.
+ room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -94,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_room_with_stats(self, room_id: str):
+ async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@@ -126,25 +121,29 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
- def get_public_room_ids(self):
- return self.db_pool.simple_select_onecol(
+ async def get_public_room_ids(self) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
desc="get_public_room_ids",
)
- def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ async def count_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ ignore_non_federatable: bool,
+ ) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
- network_tuple (ThirdPartyInstanceID|None)
- ignore_non_federatable (bool): If true filters out non-federatable rooms
+ network_tuple
+ ignore_non_federatable: If true, filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@@ -188,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
@@ -335,8 +334,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -591,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
return row
- def get_media_mxcs_in_room(self, room_id):
+ async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
- room_id (str)
+ room_id
Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
+ The local and remote media as lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@@ -615,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ async def quarantine_media_ids_in_room(
+ self, room_id: str, quarantined_by: str
+ ) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@@ -632,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -695,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- def quarantine_media_by_id(
+ async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
- ):
+ ) -> int:
"""quarantines a single local or remote media id
Args:
@@ -716,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
- def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ async def quarantine_media_ids_by_user(
+ self, user_id: str, quarantined_by: str
+ ) -> int:
"""quarantines all local media associated with a single user
Args:
@@ -732,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -1134,7 +1136,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1201,7 +1203,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1281,7 +1283,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1289,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_room_count(self):
- """Retrieve a list of all rooms
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
"""
def f(txn):
@@ -1299,13 +1301,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.db_pool.runInteraction("get_rooms", f)
+ return await self.db_pool.runInteraction("get_rooms", f)
- def add_event_report(
- self, room_id, event_id, user_id, reason, content, received_ts
- ):
+ async def add_event_report(
+ self,
+ room_id: str,
+ event_id: str,
+ user_id: str,
+ reason: str,
+ content: JsonDict,
+ received_ts: int,
+ ) -> None:
next_id = self._event_reports_id_gen.get_next()
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="event_reports",
values={
"id": next_id,
@@ -1314,7 +1322,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"event_id": event_id,
"user_id": user_id,
"reason": reason,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
desc="add_event_report",
)
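Illustration only, not part of the patch: add_event_report now serialises the report content with the shared json_encoder instead of canonicaljson's json.dumps. A rough standalone approximation of that encoder (the exact options are an assumption; see synapse.util for the real definition):

import json

# Compact separators and no NaN/Infinity, so event_reports.content stays valid JSON.
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))

content = {"reason": "spam", "score": -100}
print(json_encoder.encode(content))  # {"reason":"spam","score":-100}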
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..91a8b43da3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,9 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
- @defer.inlineCallbacks
- def _count_known_servers(self):
+ async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+ count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -155,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id: str):
- return self.db_pool.runInteraction(
+ async def get_users_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@@ -183,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id: str):
+ async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
Returns:
- Deferred[dict[str, MemberSummary]:
- dict of membership states, pointing to a MemberSummary named tuple.
+ dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@@ -264,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+ return await self.db_pool.runInteraction(
+ "get_room_summary", _get_room_summary_txn
+ )
@cached()
- def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+ async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to.
Args:
user_id: The user ID.
Returns:
- A awaitable list of RoomsForUser.
+ A list of RoomsForUser.
"""
- return self.get_rooms_for_local_user_where_membership_is(
+ return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@@ -300,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: List[str]
- ) -> Optional[List[RoomsForUser]]:
+ self, user_id: str, membership_list: Collection[str]
+ ) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -316,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
- return None
+ return []
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
@@ -360,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+ async def get_rooms_for_user_with_stream_ordering(
+ self, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
@@ -370,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id
Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
+ Returns the rooms the user is in currently, along with the stream
+ ordering of the most recent join for that user and room.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+ def _get_rooms_for_user_with_stream_ordering_txn(
+ self, txn, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -407,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
- return results
+ return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -589,11 +588,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
- list_name="event_ids",
- inlineCallbacks=True,
+ cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
- def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +598,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -716,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id: str):
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
- user_id
+ user_id: The user ID to query the rooms of.
Returns:
- Deferred[set[str]]
+ The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@@ -749,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
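Illustration only, not part of the patch: get_rooms_for_local_user_where_membership_is now returns an empty list rather than None when the membership list is empty, so callers can iterate or take len() without a None guard. A hypothetical call site (the store argument and helper name are assumptions):

async def count_pending_invites(store, user_id: str) -> int:
    # A plain list always comes back, even for an empty membership filter.
    invites = await store.get_rooms_for_local_user_where_membership_is(
        user_id, ["invite"]
    )
    return len(invites)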
@@ -772,13 +769,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
- def get_membership_from_event_ids(
+ async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
- return self.db_pool.simple_select_many_batch(
+ return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -978,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id: str, room_id: str):
+ async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -999,10 +996,10 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.db_pool.runInteraction("forget_membership", f)
+ await self.db_pool.runInteraction("forget_membership", f)
-class _JoinedHostsCache(object):
+class _JoinedHostsCache:
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.
"""
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..317fba8a5d
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a substantial part of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7f8d1880e5..f01cf2fd02 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
import logging
import re
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
from synapse.api.errors import SynapseError
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- def _find_highlights_in_postgres(self, search_query, events):
+ async def _find_highlights_in_postgres(
+ self, search_query: str, events: List[EventBase]
+ ) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
- search_query (str)
- events (list): A list of events
+ search_query
+ events: A list of events
Returns:
- deferred : A set of strings.
+ A set of strings.
"""
def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.db_pool.runInteraction("_find_highlights", f)
+ return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, List, Tuple
+
from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
- def get_event_reference_hashes(self, event_ids):
+ async def get_event_reference_hashes(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Dict[str, bytes]]:
+ """Get all hashes for given events.
+
+ Args:
+ event_ids: The event IDs to get hashes for.
+
+ Returns:
+ A mapping of event ID to a mapping of algorithm to hash.
+ """
+
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
- return self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return await self.db_pool.runInteraction("get_event_reference_hashes", f)
- async def add_event_hashes(self, event_ids):
+ async def add_event_hashes(
+ self, event_ids: Iterable[str]
+ ) -> List[Tuple[str, Dict[str, str]]]:
+ """
+
+ Args:
+ event_ids: The event IDs
+
+ Returns:
+ A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+ """
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
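Illustration only, not part of the patch: a standalone sketch of the transformation add_event_hashes performs, keeping only sha256 reference hashes and base64-encoding them. The event ID and digest below are made up; unpaddedbase64 is the dependency actually imported above.

import hashlib
from unpaddedbase64 import encode_base64

# Fake input in the shape returned by get_event_reference_hashes:
# event ID -> algorithm -> raw hash bytes.
hashes = {"$event1:example.com": {"sha256": hashlib.sha256(b"event1").digest()}}

encoded = {
    e_id: {alg: encode_base64(h) for alg, h in algs.items() if alg == "sha256"}
    for e_id, algs in hashes.items()
}
print(list(encoded.items()))  # [('$event1:example.com', {'sha256': '...'})]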
- def _get_event_reference_hashes_txn(self, txn, event_id):
+ def _get_event_reference_hashes_txn(
+ self, txn: Cursor, event_id: str
+ ) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
- txn (cursor):
- event_id (str): Id for the Event.
+ txn:
+ event_id: Id for the Event.
Returns:
- A dict[unicode, bytes] of algorithm -> hash.
+ A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..5c6168e301 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
- def get_current_state_ids(self, room_id):
+ async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
Args:
- room_id (str)
+ room_id: The room to get the state IDs of.
Returns:
- deferred: dict of (type, state_key) -> event_id
+ The current state of the room.
"""
def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(
+ async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
from the database.
Returns:
- defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+ Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
- return self.get_current_state_ids(room_id)
+ return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@@ -260,8 +261,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+ return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -273,12 +274,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
- inlineCallbacks=True,
)
- def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 0d963c98ff..356623fc6e 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -14,8 +14,7 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Tuple
from synapse.storage._base import SQLBaseStore
@@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
+ async def get_current_state_deltas(
+ self, prev_stream_id: int, max_stream_id: int
+ ) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
if it's new state.
Args:
- prev_stream_id (int): point to get changes since (exclusive)
- max_stream_id (int): the point that we know has been correctly persisted
+ prev_stream_id: point to get changes since (exclusive)
+ max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
- Deferred[tuple[int, list[dict]]: A tuple consisting of:
+ A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return defer.succeed((max_stream_id, []))
+ return (max_stream_id, [])
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
@@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
retcol="COALESCE(MAX(stream_id), -1)",
)
- def get_max_stream_id_in_current_state_deltas(self):
- return self.db_pool.runInteraction(
+ async def get_max_stream_id_in_current_state_deltas(self):
+ return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
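Illustration only, not part of the patch: in get_current_state_deltas the early return changes from defer.succeed((max_stream_id, [])) to a plain tuple, because inside an async function returning the value directly gives the caller an awaitable with that result. A generic, runnable sketch (not Synapse code; the condition is invented):

import asyncio

async def get_deltas(prev_stream_id: int, max_stream_id: int):
    # A plain return replaces defer.succeed(...); the caller still awaits it.
    if prev_stream_id >= max_stream_id:
        return max_stream_id, []
    return max_stream_id, [{"room_id": "!room:example.com"}]

print(asyncio.run(get_deltas(5, 5)))  # (5, [])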
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..55a250ef06 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,8 +15,9 @@
# limitations under the License.
import logging
+from collections import Counter
from itertools import chain
-from typing import Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.defer import DeferredLock
@@ -211,26 +212,44 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
- def get_stats_positions(self):
+ async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
desc="stats_incremental_position",
)
- def update_room_state(self, room_id, fields):
- """
+ async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
+ """Update the state of a room.
+
+ fields can contain the following keys with string values:
+ * join_rules
+ * history_visibility
+ * encryption
+ * name
+ * topic
+ * avatar
+ * canonical_alias
+
+ An is_federatable key can also be included with a boolean value.
+
Args:
- room_id (str)
- fields (dict[str:Any])
+ room_id: The room ID to update the state of.
+ fields: The fields to update. This can include a partial list of the
+ above fields to only update some room information.
"""
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
+ # Ensure that the values to update are valid, they should be strings and
+ # not contain any null bytes.
+ #
+ # Invalid data gets overwritten with null.
+ #
+ # Note that a missing value should not be overwritten (it keeps the
+ # previous value).
+ sentinel = object()
for col in (
"join_rules",
"history_visibility",
@@ -240,32 +259,34 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
):
- field = fields.get(col)
- if field and "\0" in field:
+ field = fields.get(col, sentinel)
+ if field is not sentinel and (not isinstance(field, str) or "\0" in field):
fields[col] = None
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
desc="update_room_state",
)
- def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
+ async def get_statistics_for_subject(
+ self, stats_type: str, stats_id: str, start: int, size: int = 100
+ ) -> List[dict]:
"""
Get statistics for a given subject.
Args:
- stats_type (str): The type of subject
- stats_id (str): The ID of the subject (e.g. room_id or user_id)
- start (int): Pagination start. Number of entries, not timestamp.
- size (int): How many entries to return.
+ stats_type: The type of subject
+ stats_id: The ID of the subject (e.g. room_id or user_id)
+ start: Pagination start. Number of entries, not timestamp.
+ size: How many entries to return.
Returns:
- Deferred[list[dict]], where the dict has the keys of
+ A list of dicts, where each dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_statistics_for_subject",
self._get_statistics_for_subject_txn,
stats_type,
@@ -300,7 +321,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- def get_earliest_token_for_stats(self, stats_type, id):
+ async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,29 +329,28 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
- Deferred[int]
+ The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
allow_none=True,
)
- def bulk_update_stats_delta(self, ts, updates, stream_id):
+ async def bulk_update_stats_delta(
+ self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ ) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
Args:
- ts (int): Current timestamp in ms
- updates(dict[str, dict[str, dict[str, Counter]]]): The updates to
- commit as a mapping stats_type -> stats_id -> field -> delta.
- stream_id (int): Current position.
-
- Returns:
- Deferred
+ ts: Current timestamp in ms
+ updates: The updates to commit as a mapping of
+ stats_type -> stats_id -> field -> delta.
+ stream_id: Current position.
"""
def _bulk_update_stats_delta_txn(txn):
@@ -355,38 +375,37 @@ class StatsStore(StateDeltasStore):
updatevalues={"stream_id": stream_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn
)
- def update_stats_delta(
+ async def update_stats_delta(
self,
- ts,
- stats_type,
- stats_id,
- fields,
- complete_with_stream_id,
- absolute_field_overrides=None,
- ):
+ ts: int,
+ stats_type: str,
+ stats_id: str,
+ fields: Dict[str, int],
+ complete_with_stream_id: Optional[int],
+ absolute_field_overrides: Optional[Dict[str, int]] = None,
+ ) -> None:
"""
Updates the statistics for a subject, with a delta (difference/relative
change).
Args:
- ts (int): timestamp of the change
- stats_type (str): "room" or "user" – the kind of subject
- stats_id (str): the subject's ID (room ID or user ID)
- fields (dict[str, int]): Deltas of stats values.
- complete_with_stream_id (int, optional):
+ ts: timestamp of the change
+ stats_type: "room" or "user" – the kind of subject
+ stats_id: the subject's ID (room ID or user ID)
+ fields: Deltas of stats values.
+ complete_with_stream_id:
If supplied, converts an incomplete row into a complete row,
with the supplied stream_id marked as the stream_id where the
row was completed.
- absolute_field_overrides (dict[str, int]): Current stats values
- (i.e. not deltas) of absolute fields.
- Does not work with per-slice fields.
+ absolute_field_overrides: Current stats values (i.e. not deltas) of
+ absolute fields. Does not work with per-slice fields.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_stats_delta",
self._update_stats_delta_txn,
ts,
@@ -646,19 +665,20 @@ class StatsStore(StateDeltasStore):
txn, into_table, all_dest_keyvalues, src_row
)
- def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+ async def get_changes_room_total_events_and_bytes(
+ self, min_pos: int, max_pos: int
+ ) -> Dict[str, Dict[str, int]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
- min_pos (int)
- max_pos (int)
+ min_pos
+ max_pos
Returns:
- Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field
- changes.
+ Mapping of room ID to field changes.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn,
min_pos,
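Illustration only, not part of the patch: update_room_state above uses a sentinel to tell "key absent" (keep the stored value) apart from "key present but invalid" (overwrite with null). A standalone sketch of that pattern with made-up field values:

sentinel = object()
fields = {"name": "ok", "topic": "bad\0bytes", "avatar": 123}

for col in ("join_rules", "name", "topic", "avatar"):
    value = fields.get(col, sentinel)
    # Only keys that are present but not valid strings are nulled out;
    # absent keys keep their previously stored value in the upsert.
    if value is not sentinel and (not isinstance(value, str) or "\0" in value):
        fields[col] = None

print(fields)  # {'name': 'ok', 'topic': None, 'avatar': None}; join_rules untouched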
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..db20a3db30 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,18 +39,27 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -68,8 +77,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine
-):
+ direction: str,
+ column_names: Tuple[str, str],
+ from_token: Optional[Tuple[int, int]],
+ to_token: Optional[Tuple[int, int]],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -90,21 +103,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction (str): Whether we're paginating backwards("b") or
- forwards ("f").
- column_names (tuple[str, str]): The column names to bound. Must *not*
- be user defined as these get inserted directly into the SQL
- statement without escapes.
- from_token (tuple[int, int]|None): The start point for the pagination.
- This is an exclusive minimum bound if direction is "f", and an
- inclusive maximum bound if direction is "b".
- to_token (tuple[int, int]|None): The endpoint point for the pagination.
- This is an inclusive maximum bound if direction is "f", and an
- exclusive minimum bound if direction is "b".
+ direction: Whether we're paginating backwards ("b") or forwards ("f").
+ column_names: The column names to bound. Must *not* be user defined as
+ these get inserted directly into the SQL statement without escapes.
+ from_token: The start point for the pagination. This is an exclusive
+ minimum bound if direction is "f", and an inclusive maximum bound if
+ direction is "b".
+ to_token: The end point for the pagination. This is an inclusive
+ maximum bound if direction is "f", and an exclusive minimum bound if
+ direction is "b".
engine: The database engine to generate the clauses for
Returns:
- str: The sql expression
+ The SQL expression
"""
assert direction in ("b", "f")
@@ -132,7 +143,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+ bound: str,
+ column_names: Tuple[str, str],
+ values: Tuple[Optional[int], int],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -142,18 +158,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually.
Args:
- bound (str): The comparison operator to use. One of ">", "<", ">=",
+ bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
- names (tuple[str, str]): The column names. Must *not* be user defined
+ column_names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
- values (tuple[int|None, int]): The values to bound the columns by. If
+ values: The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
- str
+ The SQL statement
"""
assert bound in (">", "<", ">=", "<=")
@@ -193,7 +209,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
)
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -251,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -284,41 +300,42 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
- @defer.inlineCallbacks
- def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order="DESC"
- ):
+ async def get_room_events_stream_for_rooms(
+ self,
+ room_ids: Collection[str],
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_ids
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[dict[str,tuple[list[FrozenEvent], str]]]
- A map from room id to a tuple containing:
- - list of recent events in the room
- - stream ordering key for the start of the chunk of events returned.
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
- room_ids = yield self._events_stream_cache.get_entities_changed(
- room_ids, from_id
- )
+ room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
if not room_ids:
return {}
@@ -326,7 +343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
- res = yield make_deferred_yieldable(
+ res = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -346,43 +363,47 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- def get_rooms_that_changed(self, room_ids, from_key):
+ def get_rooms_that_changed(
+ self, room_ids: Collection[str], from_key: str
+ ) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
- room_ids (list)
- from_key (str): The room_key portion of a StreamToken
+ room_ids
+ from_key: The room_key portion of a StreamToken
"""
- from_key = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
- if self._events_stream_cache.has_entity_changed(room_id, from_key)
+ if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
- @defer.inlineCallbacks
- def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order="DESC"
- ):
-
+ async def get_room_events_stream_for_room(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_id
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
- events (in ascending order) and the token from the start of
- the chunk of events returned.
+ The list of events (in ascending order) and the token from the start
+ of the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -390,9 +411,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
+ has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
if not has_changed:
return [], from_key
@@ -410,9 +429,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+ rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -430,8 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- @defer.inlineCallbacks
- def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(
+ self, user_id: str, from_key: str, to_key: str
+ ) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -460,9 +480,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
+ rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -470,27 +490,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
- @defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token):
+ async def get_recent_events_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
- events and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of events and a token pointing to the start of the returned
+ events. The events returned are in ascending order.
"""
- rows, token = yield self.get_recent_event_ids_for_room(
+ rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -498,20 +517,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
- @defer.inlineCallbacks
- def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ async def get_recent_event_ids_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
- _EventDictReturn and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of _EventDictReturn and a token pointing to the start of the
+ returned events. The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -519,7 +537,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -532,16 +550,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
- def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ async def get_room_event_before_stream_ordering(
+ self, room_id: str, stream_ordering: int
+ ) -> Tuple[int, int, str]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
- Deferred[(int, int, str)]:
- (stream ordering, topological ordering, event_id)
+ A tuple of (stream ordering, topological ordering, event_id)
"""
def _f(txn):
@@ -556,7 +575,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
+ return await self.db_pool.runInteraction(
+ "get_room_event_before_stream_ordering", _f
+ )
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
@@ -574,57 +595,77 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
- def get_stream_token_for_event(self, event_id):
+ async def get_stream_id_for_event(self, event_id: str) -> int:
+ """The stream ID for an event
+ Args:
+ event_id: The id of the event to look up a stream token for.
+ Raises:
+ StoreError if the event wasn't in the database.
+ Returns:
+ A stream ID.
+ """
+ return await self.db_pool.runInteraction(
+ "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+ )
+
+ def get_stream_id_for_event_txn(
+ self, txn: LoggingTransaction, event_id: str, allow_none: bool = False,
+ ) -> int:
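+ """Look up the stream ordering of the given event in the events table."""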
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="stream_ordering",
+ allow_none=allow_none,
+ )
+
+ async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "s%d" stream token.
+ A "s%d" stream token.
"""
- return self.db_pool.simple_select_one_onecol(
- table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
- ).addCallback(lambda row: "s%d" % (row,))
+ stream_id = await self.get_stream_id_for_event(event_id)
+ return "s%d" % (stream_id,)
- def get_topological_token_for_event(self, event_id):
+ async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "t%d-%d" topological token.
+ A "t%d-%d" topological token.
"""
- return self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
- ).addCallback(
- lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
+ return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
- def get_max_topological_token(self, room_id, stream_key):
- """Get the max topological token in a room before the given stream
+ async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
+ """Gets the topological token in a room after or at the given stream
ordering.
Args:
- room_id (str)
- stream_key (int)
-
- Returns:
- Deferred[int]
+ room_id
+ stream_key
"""
sql = (
- "SELECT coalesce(max(topological_ordering), 0) FROM events"
- " WHERE room_id = ? AND stream_ordering < ?"
+ "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
+ " WHERE room_id = ? AND stream_ordering >= ?"
)
- return self.db_pool.execute(
- "get_max_topological_token", None, sql, room_id, stream_key
- ).addCallback(lambda r: r[0][0] if r else 0)
+ row = await self.db_pool.execute(
+ "get_current_topological_token", None, sql, room_id, stream_key
+ )
+ return row[0][0] if row else 0
- def _get_max_topological_txn(self, txn, room_id):
+ def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@@ -634,16 +675,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
- def _set_before_and_after(events, rows, topo_order=True):
+ def _set_before_and_after(
+ events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+ ):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
- events (list[FrozenEvent])
- rows (list[_EventDictReturn])
- topo_order (bool): Whether the events were ordered topologically
- or by stream ordering. If true then all rows should have a non
- null topological_ordering.
+ events
+ rows
+ topo_order: Whether the events were ordered topologically or by stream
+ ordering. If true then all rows should have a non null
+ topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -656,25 +699,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
- @defer.inlineCallbacks
- def get_events_around(
- self, room_id, event_id, before_limit, after_limit, event_filter=None
- ):
+ async def get_events_around(
+ self,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter] = None,
+ ) -> dict:
"""Retrieve events and pagination tokens around a given event in a
room.
-
- Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
-
- Returns:
- dict
"""
- results = yield self.db_pool.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -684,11 +721,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self.get_events_as_list(
+ events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
- events_after = yield self.get_events_as_list(
+ events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -700,17 +737,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
- self, txn, room_id, event_id, before_limit, after_limit, event_filter
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter],
+ ) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
+ room_id
+ event_id
+ before_limit
+ after_limit
+ event_filter
Returns:
dict
@@ -723,6 +766,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
+ # `results` cannot be None because `allow_none=False` raises if no row is found.
+ assert results is not None
+
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@@ -758,22 +804,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- @defer.inlineCallbacks
- def get_all_new_events_stream(self, from_id, current_id, limit):
+ async def get_all_new_events_stream(
+ self, from_id: int, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
- from_id (int): the stream_ordering of the last event we processed
- current_id (int): the stream_ordering of the most recently processed event
- limit (int): the maximum number of events to return
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
Returns:
- Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
- `next_id` is the next value to pass as `from_id` (it will either be the
- stream_ordering of the last returned event, or, if fewer than `limit` events
- were found, `current_id`.
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
@@ -795,11 +842,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db_pool.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
@@ -817,21 +864,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos",
)
- async def update_federation_out_pos(self, typ, stream_id):
+ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
- return await self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
- def _reset_federation_positions_txn(self, txn):
+ def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -870,7 +917,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type
"""
txn.execute(sql)
- min_positions = dict(txn) # Map from type -> min position
+ min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
@@ -892,39 +939,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id},
)
- def has_room_changed_since(self, room_id, stream_id):
+ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
- txn,
- room_id,
- from_token,
- to_token=None,
- direction="b",
- limit=-1,
- event_filter=None,
- ):
+ txn: LoggingTransaction,
+ room_id: str,
+ from_token: RoomStreamToken,
+ to_token: Optional[RoomStreamToken] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token.
Args:
txn
- room_id (str)
- from_token (RoomStreamToken): The token used to stream from
- to_token (RoomStreamToken|None): A token which if given limits the
- results to only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
+ room_id
+ from_token: The token used to stream from
+ to_token: A token which, if given, limits the results to only those before it
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided, filters the events to
those that match the filter.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
- as a list of _EventDictReturn and a token that points to the end
- of the result set. If no events are returned then the end of the
- stream has been reached (i.e. there are no events between
- `from_token` and `to_token`), or `limit` is zero.
+ A list of _EventDictReturn and a token that points to the end of the
+ result set. If no events are returned then the end of the stream has
+ been reached (i.e. there are no events between `from_token` and
+ `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -1008,35 +1053,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
- @defer.inlineCallbacks
- def paginate_room_events(
- self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
- ):
+ async def paginate_room_events(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: Optional[str] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token.
Args:
- room_id (str)
- from_key (str): The token used to stream from
- to_key (str|None): A token which if given limits the results to
- only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
- those that match the filter.
+ room_id
+ from_key: The token used to stream from
+ to_key: A token which, if given, limits the results to only those before it
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided, filters the events to those that match the filter.
Returns:
- tuple[list[FrozenEvent], str]: Returns the results as a list of
- events and a token that points to the end of the result set. If no
- events are returned then the end of the stream has been reached
- (i.e. there are no events between `from_key` and `to_key`).
+ The results as a list of events and a token that points to the end
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between `from_key`
+ and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -1047,7 +1095,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -1057,8 +1105,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token()
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..96ffe26cc9 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
import logging
from typing import Dict, List, Tuple
-from canonicaljson import json
-
from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- tags_by_room = {}
+ tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
- tags.append(json.dumps(tag) + ":" + content)
+ tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))
@@ -124,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
- ) -> Dict[str, List[str]]:
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
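+ # get_next() now returns an awaitable; awaiting it gives the context
+ # manager that allocates the next account data stream ID.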
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 52668dbdf9..5b31aab700 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,12 +15,14 @@
import logging
from collections import namedtuple
+from typing import Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
+from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
db_binary_type = memoryview
@@ -55,21 +57,23 @@ class TransactionStore(SQLBaseStore):
expiry_ms=5 * 60 * 1000,
)
- def get_received_txn_response(self, transaction_id, origin):
+ async def get_received_txn_response(
+ self, transaction_id: str, origin: str
+ ) -> Optional[Tuple[int, JsonDict]]:
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
body (as a dict).
Args:
- transaction_id (str)
- origin(str)
+ transaction_id
+ origin
Returns:
- tuple: None if we have not previously responded to
- this transaction or a 2-tuple of (int, dict)
+ None if we have not previously responded to this transaction or a
+ 2-tuple of (int, dict)
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -98,20 +102,21 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_received_txn_response(self, transaction_id, origin, code, response_dict):
- """Persist the response we returened for an incoming transaction, and
+ async def set_received_txn_response(
+ self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
+ ) -> None:
+ """Persist the response we returned for an incoming transaction, and
should return for subsequent transactions with the same transaction_id
and origin.
Args:
- txn
- transaction_id (str)
- origin (str)
- code (int)
- response_json (str)
+ transaction_id: The incoming transaction ID.
+ origin: The origin server.
+ code: The response code.
+ response_dict: The response, to be encoded into JSON.
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -164,21 +169,25 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_destination_retry_timings(
- self, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ async def set_destination_retry_timings(
+ self,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occurring.
Args:
- destination (str)
- failure_ts (int|None) - when the server started failing (ms since epoch)
- retry_last_ts (int) - time of last retry attempt in unix epoch ms
- retry_interval (int) - how long until next retry in ms
+ destination
+ failure_ts: when the server started failing (ms since epoch)
+ retry_last_ts: time of last retry attempt in unix epoch ms
+ retry_interval: how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
@@ -254,13 +263,13 @@ class TransactionStore(SQLBaseStore):
"cleanup_transactions", self._cleanup_transactions
)
- def _cleanup_transactions(self):
+ async def _cleanup_transactions(self) -> None:
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..b89668d561 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
-from canonicaljson import json
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
@attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
- values={"result": json.dumps(result)},
+ values={"result": json_encoder.encode(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
The dictionary from the client root level, not the 'auth' key.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
await self.db_pool.simple_update_one(
table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
- def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ def _set_ui_auth_session_data_txn(
+ self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+ ):
# Get the current value.
result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- )
+ ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- updatevalues={"serverdict": json.dumps(serverdict)},
+ updatevalues={"serverdict": json_encoder.encode(serverdict)},
)
async def get_ui_auth_session_data(
@@ -258,9 +260,37 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ) -> None:
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
class UIAuthStore(UIAuthWorkerStore):
- def delete_old_ui_auth_sessions(self, expiration_time: int):
+ async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -269,18 +299,29 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
)
- def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ def _delete_old_ui_auth_sessions_txn(
+ self, txn: LoggingTransaction, expiration_time: int
+ ):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
import logging
import re
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -364,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
- def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+ async def update_profile_in_user_dir(
+ self, user_id: str, display_name: str, avatar_url: str
+ ) -> None:
"""
Update or add a user's profile in the user directory.
"""
+ # If the display name or avatar URL are unexpected types, overwrite them with None.
+ if not isinstance(display_name, str):
+ display_name = None
+ if not isinstance(avatar_url, str):
+ avatar_url = None
def _update_profile_in_user_dir_txn(txn):
new_entry = self.db_pool.simple_upsert_txn(
@@ -457,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def add_users_who_share_private_room(self, room_id, user_id_tuples):
+ async def add_users_who_share_private_room(
+ self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ room_id
+ user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
@@ -483,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def add_users_in_public_rooms(self, room_id, user_ids):
+ async def add_users_in_public_rooms(
+ self, room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_ids (list[str])
+ room_id
+ user_ids
"""
def _add_users_in_public_rooms_txn(txn):
@@ -507,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
- def delete_all_from_user_dir(self):
+ async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@@ -522,13 +534,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
- def get_user_in_directory(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -536,8 +548,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- def update_user_directory_stream_pos(self, stream_id):
- return self.db_pool.simple_update_one(
+ async def update_user_directory_stream_pos(self, stream_id: int) -> None:
+ await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@@ -554,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
- def remove_from_user_dir(self, user_id):
+ async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -577,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn
)
@@ -604,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
- def remove_user_who_share_room(self, user_id, room_id):
+ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
- user_id (str)
- room_id (str)
+ user_id
+ room_id
"""
def _remove_user_who_share_room_txn(txn):
@@ -631,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@@ -663,8 +675,50 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- def get_user_directory_stream_pos(self):
- return self.db_pool.simple_select_one_onecol(
+ @cached()
+ async def get_shared_rooms_for_users(
+ self, user_id: str, other_user_id: str
+ ) -> Set[str]:
+ """
+ Returns the rooms that a local user shares with another local or remote user.
+
+ Args:
+ user_id: The MXID of a local user
+ other_user_id: The MXID of the other user
+
+ Returns:
+ A set of room IDs that the users share.
+ """
+
+ def _get_shared_rooms_for_users_txn(txn):
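+ # Rooms are shared either publicly (both users appear in
+ # users_in_public_rooms) or privately (the pair appears in
+ # users_who_share_private_rooms); take the union of both.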
+ txn.execute(
+ """
+ SELECT p1.room_id
+ FROM users_in_public_rooms as p1
+ INNER JOIN users_in_public_rooms as p2
+ ON p1.room_id = p2.room_id
+ AND p1.user_id = ?
+ AND p2.user_id = ?
+ UNION
+ SELECT room_id
+ FROM users_who_share_private_rooms
+ WHERE
+ user_id = ?
+ AND other_user_id = ?
+ """,
+ (user_id, other_user_id, user_id, other_user_id),
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows
+
+ rows = await self.db_pool.runInteraction(
+ "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ )
+
+ return {row["room_id"] for row in rows}
+
+ async def get_user_directory_stream_pos(self) -> int:
+ return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore):
@cached()
- def is_user_erased(self, user_id):
+ async def is_user_erased(self, user_id: str) -> bool:
"""
Check if the given user id has requested erasure
Args:
- user_id (str): full user id to check
+ user_id: full user id to check
Returns:
- Deferred[bool]: True if the user has requested erasure
+ True if the user has requested erasure
"""
- return self.db_pool.simple_select_onecol(
+ result = await self.db_pool.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
desc="is_user_erased",
- ).addCallback(operator.truth)
+ )
+ return bool(result)
- @cachedList(
- cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
- )
- def are_users_erased(self, user_ids):
+ @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+ async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user ids to check
Returns:
- Deferred[dict[str, bool]]:
+ dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,12 +62,11 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
- res = {u: u in erased_users for u in user_ids}
- return res
+ return {u: u in erased_users for u in user_ids}
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id: str) -> None:
+ async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
@@ -88,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_erased", f)
+ await self.db_pool.runInteraction("mark_user_erased", f)
- def mark_user_not_erased(self, user_id: str) -> None:
+ async def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
@@ -110,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_not_erased", f)
+ await self.db_pool.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7f104ad936..e924f1ca3b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -17,8 +17,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)
@@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
- def store_state_group(
+ async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
+ ) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
@@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to event_id.
Returns:
- Deferred[int]: The state group ID
+ The state group ID
"""
def _store_state_group_txn(txn):
@@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_group
- return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
+ return await self.db_pool.runInteraction(
+ "store_state_group", _store_state_group_txn
+ )
- def purge_unreferenced_state_groups(
+ async def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
- ) -> defer.Deferred:
+ ) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to delete.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(self, room_id, state_groups_to_delete):
"""Deletes all record of a room from state tables
Args:
@@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete (list[int]): State groups to delete
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
|