diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0ac854aee2..c73d54fb67 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -68,7 +68,7 @@ class Databases(object):
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70cf15dd7f..99890ffbf3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
import calendar
import logging
import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -264,6 +264,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@@ -291,16 +294,16 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- def count_daily_users(self):
+ async def count_daily_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)
- def count_monthly_users(self):
+ async def count_monthly_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
@@ -308,7 +311,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@@ -327,15 +330,15 @@ class DataStore(
(count,) = txn.fetchone()
return count
- def count_r30_users(self):
+ async def count_r30_users(self) -> Dict[str, int]:
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
- Returns counts globaly for a given user as well as breaking
- by platform
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
@@ -408,7 +411,7 @@ class DataStore(
return results
- return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@@ -418,7 +421,7 @@ class DataStore(
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
- def generate_user_daily_visits(self):
+ async def generate_user_daily_visits(self) -> None:
"""
Generates daily visit data for use in cohort/ retention analysis
"""
@@ -473,7 +476,7 @@ class DataStore(
# frequently
self._last_user_visit_update = now
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
@@ -497,22 +500,28 @@ class DataStore(
desc="get_users",
)
- def get_users_paginate(
- self, start, limit, user_id=None, name=None, guests=True, deactivated=False
- ):
+ async def get_users_paginate(
+ self,
+ start: int,
+ limit: int,
+ user_id: Optional[str] = None,
+ name: Optional[str] = None,
+ guests: bool = True,
+ deactivated: bool = False,
+ ) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
- start (int): start number to begin the query from
- limit (int): number of rows to retrieve
- user_id (string): search for user_id. ignored if name is not None
- name (string): search for local part of user_id or display name
- guests (bool): whether to in include guest users
- deactivated (bool): whether to include deactivated users
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ user_id: search for user_id. ignored if name is not None
+ name: search for local part of user_id or display name
+ guests: whether to in include guest users
+ deactivated: whether to include deactivated users
Returns:
- defer.Deferred: resolves to list[dict[str, Any]], int
+ A tuple of a list of mappings from user to information and a count of total users.
"""
def get_users_paginate_txn(txn):
@@ -555,7 +564,7 @@ class DataStore(
users = self.db_pool.cursor_to_dict(txn)
return users, count
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 04042a2c98..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
import abc
import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
- def get_account_data_for_user(self, user_id):
+ async def get_account_data_for_user(
+ self, user_id: str
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
- user_id(str): The user to get the account_data for.
+ user_id: The user to get the account_data for.
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A 2-tuple of a dict of global account_data and a dict mapping from
+ room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
- def get_account_data_for_room(self, user_id, room_id):
+ async def get_account_data_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
Returns:
- A deferred dict of the room account_data
+ A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
- def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ async def get_account_data_for_room_and_type(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
- account_data_type (str): The account data type to get.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
+ account_data_type: The account data type to get.
Returns:
- A deferred of the room account_data for that type, or None if
- there isn't any set.
+ The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ async def get_updated_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
- user_id(str): The user to get the account_data for.
- stream_id(int): The point in the stream since which to get updates
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return defer.succeed(({}, {}))
+ return ({}, {})
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
- def _update_max_stream_id(self, next_id: int):
+ async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+ await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 4e2b2a85ee..d568789124 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
- def _update_client_ips_batch(self):
+ async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..f8fe948122 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing an device update EDU
"""
devices = (
- await self.db_pool.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
+ await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -292,17 +291,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -312,9 +311,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- def _get_last_device_update_for_remote_user(
+ async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
- ):
+ ) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@@ -325,12 +324,16 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+ return await self.db_pool.runInteraction(
+ "get_last_device_update_for_remote_user", f
+ )
- def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+ async def mark_as_sent_devices_by_remote(
+ self, destination: str, stream_id: int
+ ) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -412,8 +415,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
+ @abc.abstractmethod
def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
async def get_user_devices_from_cache(
@@ -481,51 +486,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id: str):
- """Get all devices (with any device keys) for a user
-
- Returns:
- Deferred which resolves to (stream_id, devices)
- """
- return self.db_pool.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in user_devices.items():
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
@@ -726,7 +686,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+ async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
@@ -740,7 +700,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
@@ -1001,9 +961,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
- def update_remote_device_list_cache_entry(
+ async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
- ):
+ ) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@@ -1014,11 +974,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: ID of decivice being updated
content: new data on this device
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -1070,9 +1027,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(
+ async def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
@@ -1082,11 +1039,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: User to update device list for
devices: list of device objects supplied over federation
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1096,7 +1050,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 405b5eafa5..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(
+ async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
- ):
+ ) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..cc0b15ae07 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,8 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@@ -23,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -31,19 +34,67 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
+@attr.s
+class DeviceKeyLookupResult:
+ """The type returned by get_e2e_device_keys_and_signatures"""
+
+ display_name = attr.ib(type=Optional[str])
+
+ # the key data from e2e_device_keys_json. Typically includes fields like
+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+ # key) and "signatures" (a signature of the structure by the ed25519 key)
+ key_json = attr.ib(type=Optional[str])
+
+ # cross-signing sigs
+ signatures = attr.ib(type=Optional[Dict], default=None)
+
+
class EndToEndKeyWorkerStore(SQLBaseStore):
+ async def get_e2e_device_keys_for_federation_query(
+ self, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ now_stream_id = self.get_device_stream_token()
+
+ devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ key_json = device.key_json
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
+ device_display_name = device.display_name
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
@trace
- async def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
+ async def get_e2e_device_keys_for_cs_api(
+ self, query_list: List[Tuple[str, Optional[str]]]
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -53,13 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = await self.db_pool.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
- )
+ results = await self.get_e2e_device_keys_and_signatures(query_list)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
@@ -67,13 +112,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
- r = db_to_json(device_info.pop("key_json"))
+ r = db_to_json(device_info.key_json)
r["unsigned"] = {}
- display_name = device_info["device_display_name"]
+ display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
- if "signatures" in device_info:
- for sig_user_id, sigs in device_info["signatures"].items():
+ if device_info.signatures:
+ for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
@@ -82,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return rv
@trace
- def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: List[Tuple[str, Optional[str]]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Fetch a list of device keys, together with their cross-signatures.
+
+ Args:
+ query_list: List of pairs of user_ids and device_ids. Device id can be None
+ to indicate "all devices for this user"
+
+ include_all_devices: whether to return devices without device keys
+
+ include_deleted_devices: whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
+
+ Returns:
+ Dict mapping from user-id to dict mapping from device_id to
+ key data.
+ """
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
+ result = await self.db_pool.runInteraction(
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_and_signatures_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
+ )
+
+ log_kv(result)
+ return result
+
+ def _get_e2e_device_keys_and_signatures_txn(
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
query_clauses = []
query_params = []
signature_query_clauses = []
@@ -119,7 +197,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = (
"SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
+ " d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -130,13 +208,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
- result = {}
- for row in rows:
+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, key_json
+ )
if include_deleted_devices:
for user_id, device_id in deleted_devices:
@@ -167,13 +246,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices.
continue
- target_device_signatures = target_device_result.setdefault("signatures", {})
+ target_device_signatures = target_device_result.signatures
+ if target_device_signatures is None:
+ target_device_signatures = target_device_result.signatures = {}
+
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
- log_kv(result)
return result
async def get_e2e_one_time_keys(
@@ -252,10 +333,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
+ async def count_e2e_one_time_keys(
+ self, user_id: str, device_id: str
+ ) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
+ A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@@ -270,7 +353,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@@ -308,7 +391,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
- def _get_bare_e2e_cross_signing_keys_bulk(
+ async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@@ -316,16 +399,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
- user_ids (list[str]): the users whose keys are being requested
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A mapping from user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict, or
+ their user ID will map to None.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@@ -541,9 +623,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ async def set_e2e_device_keys(
+ self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ ) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@@ -579,12 +668,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
+ async def claim_e2e_one_time_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ """Take a list of one time keys out of the database.
+
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+ """
@trace
def _claim_e2e_one_time_keys(txn):
@@ -620,11 +718,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- def delete_e2e_keys_by_device(self, user_id, device_id):
+ async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@@ -647,7 +745,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..4c3c162acf 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
- raise StoreError(400, "stream_ordering too old")
+ raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e8834b2162..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import List
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
- ):
+ self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ ) -> Dict[str, int]:
+ """Get the notification count, the highlight count and the unread message count
+ for a given user in a given room after the given read receipt.
+
+ Note that this function assumes the user to be a current member of the room,
+ since it's either called by the sync handler to handle joined room entries, or by
+ the HTTP pusher to calculate the badge of unread joined rooms.
+
+ Args:
+ room_id: The room to retrieve the counts in.
+ user_id: The user to retrieve the counts for.
+ last_read_event_id: The event associated with the latest read receipt for
+ this user in this room. None if no receipt for this user in this room.
+
+ Returns
+ A dict containing the counts mentioned earlier in this docstring,
+ respectively under the keys "notify_count", "highlight_count" and
+ "unread_count".
+ """
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id
+ self, txn, room_id, user_id, last_read_event_id,
):
- sql = (
- "SELECT stream_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(sql, (room_id, last_read_event_id))
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
+ stream_ordering = None
+
+ if last_read_event_id is not None:
+ stream_ordering = self.get_stream_id_for_event_txn(
+ txn, last_read_event_id, allow_none=True,
+ )
+
+ if stream_ordering is None:
+ # Either last_read_event_id is None, or it's an event we don't have (e.g.
+ # because it's been purged), in which case retrieve the stream ordering for
+ # the latest membership event from this user in this room (which we assume is
+ # a join).
+ event_id = self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ retcol="event_id",
+ )
- stream_ordering = results[0][0]
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
sql = (
- "SELECT count(*)"
+ "SELECT"
+ " COUNT(CASE WHEN notif = 1 THEN 1 END),"
+ " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+ " COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
+ " WHERE user_id = ?"
+ " AND room_id = ?"
+ " AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- notify_count = row[0] if row else 0
+
+ (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+ if row:
+ (notif_count, highlight_count, unread_count) = row
txn.execute(
"""
- SELECT notif_count FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
- """,
+ SELECT notif_count, unread_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """,
(room_id, user_id, stream_ordering),
)
- rows = txn.fetchall()
- if rows:
- notify_count += rows[0][0]
-
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
-
- txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- highlight_count = row[0] if row else 0
- return {"notify_count": notify_count, "highlight_count": highlight_count}
+ if row:
+ notif_count += row[0]
+ unread_count += row[1]
+
+ return {
+ "notify_count": notif_count,
+ "unread_count": unread_count,
+ "highlight_count": highlight_count,
+ }
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -383,62 +409,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
return notifs[:limit]
- def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+ async def get_if_maybe_push_in_range_for_user(
+ self, user_id: str, min_stream_ordering: int
+ ) -> bool:
"""A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not.
Args:
- user_id (str)
- min_stream_ordering (int)
+ user_id
+ min_stream_ordering
Returns:
- Deferred[bool]: True if there may be push to process, False if
- there definitely isn't.
+ True if there may be push to process, False if there definitely isn't.
"""
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
- WHERE user_id = ? AND stream_ordering > ?
+ WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
- async def add_push_actions_to_staging(self, event_id, user_id_actions):
+ async def add_push_actions_to_staging(
+ self,
+ event_id: str,
+ user_id_actions: Dict[str, List[Union[dict, str]]],
+ count_as_unread: bool,
+ ) -> None:
"""Add the push actions for the event to the push action staging area.
Args:
- event_id (str)
- user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
- user_id to list of push actions, where an action can either be
- a string or dict.
-
- Returns:
- Deferred
+ event_id
+ user_id_actions: A mapping of user_id to list of push actions, where
+ an action can either be a string or dict.
+ count_as_unread: Whether this event should increment unread counts.
"""
-
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
- # can be used to inert into the `event_push_actions_staging` table.
+ # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
+ notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
- 1, # notif column
+ notif, # notif column
is_highlight, # highlight column
+ int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@@ -447,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight)
- VALUES (?, ?, ?, ?, ?)
+ (event_id, user_id, actions, notif, highlight, unread)
+ VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
@@ -507,7 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
- def find_first_stream_ordering_after_ts(self, ts):
+ async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
@@ -516,13 +546,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow.
Args:
- ts (int): timestamp in millis
+ ts: timestamp in millis
Returns:
- Deferred[int]: stream ordering of the first event received on/after
- the timestamp
+ stream ordering of the first event received on/after the timestamp
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -813,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
- coalesce(old.notif_count, 0) + upd.notif_count,
+ coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
- SELECT user_id, room_id, count(*) as notif_count,
+ SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
+ AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
- rows = txn.fetchall()
+ # First get the count of unread messages.
+ txn.execute(
+ sql % ("unread_count", "unread"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # We need to merge results from the two requests (the one that retrieves the
+ # unread count and the one that retrieves the notifications count) into a single
+ # object because we might not have the same amount of rows in each of them. To do
+ # this, we use a dict indexed on the user ID and room ID to make it easier to
+ # populate.
+ summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
+ for row in txn:
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=row[2],
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=0,
+ )
+
+ # Then get the count of notifications.
+ txn.execute(
+ sql % ("notif_count", "notif"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ for row in txn:
+ if (row[0], row[1]) in summaries:
+ summaries[(row[0], row[1])].notif_count = row[2]
+ else:
+ # Because the rules on notifying are different than the rules on marking
+ # a message unread, we might end up with messages that notify but aren't
+ # marked unread, so we might not have a summary for this (user, room)
+ # tuple to complete.
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=0,
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=row[2],
+ )
- logger.info("Rotating notifications, handling %d rows", len(rows))
+ logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
@@ -840,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary",
values=[
{
- "user_id": row[0],
- "room_id": row[1],
- "notif_count": row[2],
- "stream_ordering": row[3],
+ "user_id": user_id,
+ "room_id": room_id,
+ "notif_count": summary.notif_count,
+ "unread_count": summary.unread_count,
+ "stream_ordering": summary.stream_ordering,
}
- for row in rows
- if row[4] is None
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is None
],
)
txn.executemany(
"""
- UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ UPDATE event_push_summary
+ SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
- ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+ (
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ user_id,
+ room_id,
+ )
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is not None
+ ),
)
txn.execute(
@@ -881,3 +961,15 @@ def _action_has_highlight(actions):
pass
return False
+
+
+@attr.s
+class _EventPushSummary:
+ """Summary of pending event push actions for a given user in a given room.
+ Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+ """
+
+ unread_count = attr.ib(type=int)
+ stream_ordering = attr.ib(type=int)
+ old_user_id = attr.ib(type=str)
+ notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 6313b41eef..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -97,6 +97,7 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@@ -108,7 +109,7 @@ class PersistEventsStore:
# This should only exist on instances that are configured to write
assert (
- hs.config.worker.writers.events == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@@ -800,6 +801,7 @@ class PersistEventsStore:
table="events",
values=[
{
+ "instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@@ -1296,9 +1298,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
+ topological_ordering, notif, highlight, unread
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e6247d682d..17f5997b89 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
- if hs.config.worker.writers.events == hs.get_instance_name():
- # We are the process in charge of generating stream ids for events,
- # so instantiate ID generators based on the database
- self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ if isinstance(database.engine, PostgresEngine):
+ # If we're using Postgres than we can use `MultiWriterIdGenerator`
+ # regardless of whether this process writes to the streams or not.
+ self._stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_stream_seq",
)
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ self._backfill_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_backfill_stream_seq",
+ positive=False,
)
else:
- # Another process is in charge of persisting events and generating
- # stream IDs: rely on the replication streams to let us know which
- # IDs we can process.
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
+ # We shouldn't be running in worker mode with SQLite, but its useful
+ # to support it for unit tests.
+ #
+ # If this process is the writer than we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering",
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ self._stream_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering"
+ )
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
self._get_event_cache = Cache(
"*getEvent*",
@@ -823,20 +851,24 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict
- def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+ def _maybe_redact_event_row(
+ self,
+ original_ev: EventBase,
+ redactions: Iterable[str],
+ event_map: Dict[str, EventBase],
+ ) -> Optional[EventBase]:
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
- original_ev (EventBase):
- redactions (iterable[str]): list of event ids of potential redaction events
- event_map (dict[str, EventBase]): other events which have been fetched, in
- which we can look up the redaaction events. Map from event id to event.
+ original_ev: The original event.
+ redactions: list of event ids of potential redaction events
+ event_map: other events which have been fetched, in which we can
+ look up the redaaction events. Map from event id to event.
Returns:
- Deferred[EventBase|None]: if the event should be redacted, a pruned
- event object. Otherwise, None.
+ If the event should be redacted, a pruned event object. Otherwise, None.
"""
if original_ev.type == "m.room.create":
# we choose to ignore redactions of m.room.create events.
@@ -946,17 +978,17 @@ class EventsWorkerStore(SQLBaseStore):
row = txn.fetchone()
return row[0] if row else 0
- def get_current_state_event_counts(self, room_id):
+ async def get_current_state_event_counts(self, room_id: str) -> int:
"""
Gets the current number of state events in a room.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- Deferred[int]
+ The current number of state events.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
@@ -991,7 +1023,9 @@ class EventsWorkerStore(SQLBaseStore):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ async def get_all_new_forward_event_rows(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
@@ -999,7 +1033,7 @@ class EventsWorkerStore(SQLBaseStore):
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1020,18 +1054,20 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
- def get_ex_outlier_stream_rows(self, last_id, current_id):
+ async def get_ex_outlier_stream_rows(
+ self, last_id: int, current_id: int
+ ) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1054,7 +1090,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
@@ -1226,11 +1262,11 @@ class EventsWorkerStore(SQLBaseStore):
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_next_event_to_expire(self):
+ async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
- Returns: Deferred[Optional[Tuple[str, int]]]
+ Returns:
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
@@ -1246,6 +1282,6 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- def add_user_filter(self, user_localpart, user_filter):
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.db_pool.runInteraction("add_user_filter", _do_txn)
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3919ecad69..86557d5512 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
- def get_url_cache(self, url, ts):
+ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+ return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
@@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ async def update_cached_last_access_time(
+ self,
+ local_media: Iterable[str],
+ remote_media: Iterable[Tuple[str, str]],
+ time_ms: int,
+ ):
"""Updates the last access time of the given media
Args:
- local_media (iterable[str]): Set of media_ids
- remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ local_media: Set of media_ids
+ remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
- def delete_remote_media(self, media_origin, media_id):
+ async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
- def get_expired_url_cache(self, now_ts):
+ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"delete_url_cache", _delete_url_cache_txn
)
- def get_url_cache_media_before(self, before_ts):
+ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 4db8949da7..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from synapse.storage._base import SQLBaseStore
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
- def get_user_id_for_open_id_token(self, token, ts_now_ms):
+ async def get_user_id_for_open_id_token(
+ self, token: str, ts_now_ms: int
+ ) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 301875a672..d2e0685e9e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
- def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ async def get_remote_profile_cache_entries_that_expire(
+ self, last_checked: int
+ ) -> Dict[str, str]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3526b6fd66..ea833829ae 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
- def purge_history(self, room_id, token, delete_local_events):
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> Set[int]:
"""Deletes room history before a certain point
Args:
- room_id (str):
-
- token (str): A topological token to delete events before
-
- delete_local_events (bool):
+ room_id:
+ token: A topological token to delete events before
+ delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
Returns:
- Deferred[set[int]]: The set of state groups that are referenced by
- deleted events.
+ The set of state groups that are referenced by deleted events.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
return referenced_state_groups
- def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> List[int]:
"""Deletes all record of a room
Args:
- room_id (str)
+ room_id
Returns:
- Deferred[List[int]]: The list of state groups to delete.
+ The list of state groups to delete.
"""
-
- return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+ return await self.db_pool.runInteraction(
+ "purge_room", self._purge_room_txn, room_id
+ )
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2fb5b02d7d..0de802a86b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
-from twisted.internet import defer
-
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- def have_push_rules_changed_for_user(self, user_id, last_id):
+ async def have_push_rules_changed_for_user(
+ self, user_id: str, last_id: int
+ ) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
+ return False
else:
def have_push_rules_changed_txn(txn):
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 436f22ad2d..4a0d5a320e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
- def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+ async def get_users_sent_receipts_between(
+ self, last_id: int, current_id: int
+ ) -> List[str]:
"""Get all users who sent receipts between `last_id` exclusive and
`current_id` inclusive.
Returns:
- Deferred[List[str]]
+ The list of users.
"""
if last_id == current_id:
@@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
)
@@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
- def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.db_pool.runInteraction(
+ async def insert_graph_receipt(
+ self, room_id, receipt_type, user_id, event_ids, data
+ ):
+ return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 12689f4308..01f20c03c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -84,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
- def get_user_by_access_token(self, token):
+ async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
- token (str): The access token of a user.
+ token: The access token of a user.
Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@cached()
- async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
+ async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
@@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
- def set_server_admin(self, user, admin):
+ async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
- admin (bool): true iff the user is to be a server admin,
- false otherwise.
+ user: user ID of the user to test
+ admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
@@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),)
)
- return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
@@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
- def get_users_by_id_case_insensitive(self, user_id):
+ async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
- Returns a mapping of user_id -> password_hash.
+
+ Returns:
+ A mapping of user_id -> password_hash.
"""
def f(txn):
@@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
- return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+ return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users)
- def count_daily_user_type(self):
+ async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guests users
@@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
- def get_threepid_validation_session(
- self, medium, client_secret, address=None, sid=None, validated=True
- ):
+ async def get_threepid_validation_session(
+ self,
+ medium: Optional[str],
+ client_secret: str,
+ address: Optional[str] = None,
+ sid: Optional[str] = None,
+ validated: Optional[bool] = True,
+ ) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
- medium (str|None): The medium of the 3PID
- address (str|None): The address of the 3PID
- sid (str|None): The ID of the validation session
- client_secret (str): A unique string provided by the client to help identify this
+ medium: The medium of the 3PID
+ client_secret: A unique string provided by the client to help identify this
validation attempt
- validated (bool|None): Whether sessions should be filtered by
+ address: The address of the 3PID
+ sid: The ID of the validation session
+ validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
- Deferred[dict|None]: A dict containing the following:
+ A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
- def delete_threepid_session(self, session_id):
+ async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
- session_id (str): The ID of the session to delete
+ session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
@@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
- def register_user(
+ async def register_user(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- shadow_banned=False,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Attempts to register an account.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
- upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
- false to add a regular user account.
- appservice_id (str): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode): Optionally create a profile for
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Whether this is a guest account being upgraded to a
+ non-guest account.
+ make_guest: True if the the new user should be guest, false to add a
+ regular user account.
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
- api.constants.UserTypes, or None for a normal user.
- shadow_banned (bool): Whether the user is shadow-banned,
- i.e. they may be told their requests succeeded but we ignore them.
+ admin: is an admin user?
+ user_type: type of user. One of the values from api.constants.UserTypes,
+ or None for a normal user.
+ shadow_banned: Whether the user is shadow-banned, i.e. they may be
+ told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
-
- Returns:
- Deferred
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- def user_set_password_hash(self, user_id, password_hash):
+ async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
- def user_set_consent_version(self, user_id, consent_version):
+ async def user_set_consent_version(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy the user has consented
- to
+ user_id: full mxid of the user to update
+ consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
@@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction("user_set_consent_version", f)
+ await self.db_pool.runInteraction("user_set_consent_version", f)
- def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ async def user_set_consent_server_notice_sent(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy we have notified the
- user about
+ user_id: full mxid of the user to update
+ consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
@@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
+ await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+ async def user_delete_access_tokens(
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
+ ) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str): list of access_tokens IDs which should
- *not* be deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_tokens ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
- defer.Deferred[list[str, int, str|None, int]]: a list of
- (token, token id, device id) for each of the deleted tokens
+ A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
@@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
- return self.db_pool.runInteraction("user_delete_access_tokens", f)
+ return await self.db_pool.runInteraction("user_delete_access_tokens", f)
- def delete_access_token(self, access_token):
+ async def delete_access_token(self, access_token: str) -> None:
def f(txn):
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
@@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
- return self.db_pool.runInteraction("delete_access_token", f)
+ await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
@@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
- def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+ async def validate_threepid_session(
+ self, session_id: str, client_secret: str, token: str, current_ts: int
+ ) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
- session_id (str): The id of a validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- token (str): A validation token
- current_ts (int): The current unix time in milliseconds. Used for
- checking token expiry status
+ session_id: The id of a validation session
+ client_secret: A unique string provided by the client to help identify
+ this validation attempt
+ token: A validation token
+ current_ts: The current unix time in milliseconds. Used for checking
+ token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
- deferred str|None: A str representing a link to redirect the user
- to if there is one.
+ A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
@@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
- def start_or_continue_validation_session(
+ async def start_or_continue_validation_session(
self,
- medium,
- address,
- session_id,
- client_secret,
- send_attempt,
- next_link,
- token,
- token_expires,
- ):
+ medium: str,
+ address: str,
+ session_id: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str],
+ token: str,
+ token_expires: int,
+ ) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- session_id (str): The id of this validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- next_link (str|None): The link to redirect the user to upon
- successful validation
- token (str): The validation token
- token_expires (int): The timestamp for which after the token
- will no longer be valid
+ medium: The medium of the 3PID
+ address: The address of the 3PID
+ session_id: The id of this validation session
+ client_secret: A unique string provided by the client to help
+ identify this validation attempt
+ send_attempt: The latest send_attempt on this session
+ next_link: The link to redirect the user to upon successful validation
+ token: The validation token
+ token_expires: The timestamp for which after the token will no
+ longer be valid
"""
def start_or_continue_validation_session_txn(txn):
@@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
- def cull_expired_threepid_validation_tokens(self):
+ async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
- return txn.execute(sql, (ts,))
+ txn.execute(sql, (ts,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
- def get_relations_for_event(
+ async def get_relations_for_event(
self,
- event_id,
- relation_type=None,
- event_type=None,
- aggregation_key=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[RelationPaginationToken] = None,
+ to_token: Optional[RelationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
- event_id (str): Fetch events that relate to this event ID.
- relation_type (str|None): Only fetch events with this relation
- type, if given.
- event_type (str|None): Only fetch events with this event type, if
- given.
- aggregation_key (str|None): Only fetch events with this aggregation
- key, if given.
- limit (int): Only fetch the most recent `limit` events.
- direction (str): Whether to fetch the most recent first (`"b"`) or
- the oldest first (`"f"`).
- from_token (RelationPaginationToken|None): Fetch rows from the given
- token, or from the start if None.
- to_token (RelationPaginationToken|None): Fetch rows up to the given
- token, or up to the end if None.
+ event_id: Fetch events that relate to this event ID.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of event IDs that match relations
- requested. The rows are of the form `{"event_id": "..."}`.
+ List of event IDs that match relations requested. The rows are of
+ the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@cached(tree=True)
- def get_aggregation_groups_for_event(
+ async def get_aggregation_groups_for_event(
self,
- event_id,
- event_type=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ event_type: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[AggregationPaginationToken] = None,
+ to_token: Optional[AggregationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
on an event.
Args:
- event_id (str): Fetch events that relate to this event ID.
- event_type (str|None): Only fetch events with this event type, if
- given.
- limit (int): Only fetch the `limit` groups.
- direction (str): Whether to fetch the highest count first (`"b"`) or
+ event_id: Fetch events that relate to this event ID.
+ event_type: Only fetch events with this event type, if given.
+ limit: Only fetch the `limit` groups.
+ direction: Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
- from_token (AggregationPaginationToken|None): Fetch rows from the
- given token, or from the start if None.
- to_token (AggregationPaginationToken|None): Fetch rows up to the
- given token, or up to the end if None.
-
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of groups of annotations that
- match. Each row is a dict with `type`, `key` and `count` fields.
+ List of groups of annotations that match. Each row is a dict with
+ `type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.get_event(edit_id, allow_none=True)
- def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ async def has_user_annotated_event(
+ self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+ ) -> bool:
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
- parent_id (str): The event being annotated
- event_type (str): The event type of the annotation
- aggregation_key (str): The aggregation key of the annotation
- sender (str): The sender of the annotation
+ parent_id: The event being annotated
+ event_type: The event type of the annotation
+ aggregation_key: The aggregation key of the annotation
+ sender: The sender of the annotation
Returns:
- Deferred[bool]
+ True if the event is already annotated.
"""
sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a92641c339..717df97301 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_room_with_stats(self, room_id: str):
+ async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
- def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ async def count_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ ignore_non_federatable: bool,
+ ) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
- network_tuple (ThirdPartyInstanceID|None)
- ignore_non_federatable (bool): If true filters out non-federatable rooms
+ network_tuple
+ ignore_non_federatable: If true filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
@@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
return row
- def get_media_mxcs_in_room(self, room_id):
+ async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
- room_id (str)
+ room_id
Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
+ The local and remote media as a lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ async def quarantine_media_ids_in_room(
+ self, room_id: str, quarantined_by: str
+ ) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- def quarantine_media_by_id(
+ async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
- ):
+ ) -> int:
"""quarantines a single local or remote media id
Args:
@@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
- def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ async def quarantine_media_ids_by_user(
+ self, user_id: str, quarantined_by: str
+ ) -> int:
"""quarantines all local media associated with a single user
Args:
@@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_room_count(self):
- """Retrieve a list of all rooms
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
"""
def f(txn):
@@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.db_pool.runInteraction("get_rooms", f)
+ return await self.db_pool.runInteraction("get_rooms", f)
async def add_event_report(
self,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 161edbeccb..c46f5cd524 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id: str):
- return self.db_pool.runInteraction(
+ async def get_users_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id: str):
+ async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
Returns:
- Deferred[dict[str, MemberSummary]:
- dict of membership states, pointing to a MemberSummary named tuple.
+ dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+ return await self.db_pool.runInteraction(
+ "get_room_summary", _get_room_summary_txn
+ )
@cached()
- def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+ async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to.
Args:
user_id: The user ID.
Returns:
- A awaitable list of RoomsForUser.
+ A list of RoomsForUser.
"""
- return self.get_rooms_for_local_user_where_membership_is(
+ return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@@ -297,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: List[str]
- ) -> Optional[List[RoomsForUser]]:
+ self, user_id: str, membership_list: Collection[str]
+ ) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -313,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
- return None
+ return []
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
@@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+ async def get_rooms_for_user_with_stream_ordering(
+ self, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
@@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id
Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
+ Returns the rooms the user is in currently, along with the stream
+ ordering of the most recent join for that user and room.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+ def _get_rooms_for_user_with_stream_ordering_txn(
+ self, txn, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
- return results
+ return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id: str):
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
- user_id
+ user_id: The user ID to query the rooms of.
Returns:
- Deferred[set[str]]
+ The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id: str, room_id: str):
+ async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.db_pool.runInteraction("forget_membership", f)
+ await self.db_pool.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+ SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+ SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index dcbdeab36e..9c5f0229c1 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
import logging
import re
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
from synapse.api.errors import SynapseError
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- def _find_highlights_in_postgres(self, search_query, events):
+ async def _find_highlights_in_postgres(
+ self, search_query: str, events: List[EventBase]
+ ) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
- search_query (str)
- events (list): A list of events
+ search_query
+ events: A list of events
Returns:
- deferred : A set of strings.
+ A set of strings.
"""
def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.db_pool.runInteraction("_find_highlights", f)
+ return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, List, Tuple
+
from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
- def get_event_reference_hashes(self, event_ids):
+ async def get_event_reference_hashes(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Dict[str, bytes]]:
+ """Get all hashes for given events.
+
+ Args:
+ event_ids: The event IDs to get hashes for.
+
+ Returns:
+ A mapping of event ID to a mapping of algorithm to hash.
+ """
+
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
- return self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return await self.db_pool.runInteraction("get_event_reference_hashes", f)
- async def add_event_hashes(self, event_ids):
+ async def add_event_hashes(
+ self, event_ids: Iterable[str]
+ ) -> List[Tuple[str, Dict[str, str]]]:
+ """
+
+ Args:
+ event_ids: The event IDs
+
+ Returns:
+ A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+ """
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
- def _get_event_reference_hashes_txn(self, txn, event_id):
+ def _get_event_reference_hashes_txn(
+ self, txn: Cursor, event_id: str
+ ) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
- txn (cursor):
- event_id (str): Id for the Event.
+ txn:
+ event_id: Id for the Event.
Returns:
- A dict[unicode, bytes] of algorithm -> hash.
+ A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9b9bc304a8..55a250ef06 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore):
)
async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
- """
+ """Update the state of a room.
+
+ fields can contain the following keys with string values:
+ * join_rules
+ * history_visibility
+ * encryption
+ * name
+ * topic
+ * avatar
+ * canonical_alias
+
+ A is_federatable key can also be included with a boolean value.
+
Args:
- room_id
- fields
+ room_id: The room ID to update the state of.
+ fields: The fields to update. This can include a partial list of the
+ above fields to only update some room information.
"""
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
+ # Ensure that the values to update are valid, they should be strings and
+ # not contain any null bytes.
+ #
+ # Invalid data gets overwritten with null.
+ #
+ # Note that a missing value should not be overwritten (it keeps the
+ # previous value).
+ sentinel = object()
for col in (
"join_rules",
"history_visibility",
@@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
):
- field = fields.get(col)
- if field and "\0" in field:
+ field = fields.get(col, sentinel)
+ if field is not sentinel and (not isinstance(field, str) or "\0" in field):
fields[col] = None
await self.db_pool.simple_upsert(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..be6df8a6d1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@@ -47,12 +47,19 @@ from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -202,7 +209,7 @@ def _make_generic_sql_bound(
)
-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -293,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
async def get_room_events_stream_for_rooms(
self,
- room_ids: Iterable[str],
+ room_ids: Collection[str],
from_key: str,
to_key: str,
limit: int = 0,
@@ -356,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- def get_rooms_that_changed(self, room_ids, from_key):
+ def get_rooms_that_changed(
+ self, room_ids: Collection[str], from_key: str
+ ) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
- room_ids (list)
- from_key (str): The room_key portion of a StreamToken
+ room_ids
+ from_key: The room_key portion of a StreamToken
"""
- from_key = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
- if self._events_stream_cache.has_entity_changed(room_id, from_key)
+ if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
async def get_room_events_stream_for_room(
@@ -440,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(
+ self, user_id: str, from_key: str, to_key: str
+ ) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -593,8 +604,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A stream ID.
"""
- return await self.db_pool.simple_select_one_onecol(
- table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+ return await self.db_pool.runInteraction(
+ "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+ )
+
+ def get_stream_id_for_event_txn(
+ self, txn: LoggingTransaction, event_id: str, allow_none=False,
+ ) -> int:
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="stream_ordering",
+ allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
@@ -646,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return row[0][0] if row else 0
- def _get_max_topological_txn(self, txn, room_id):
+ def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@@ -719,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_events_around_txn(
self,
- txn,
+ txn: LoggingTransaction,
room_id: str,
event_id: str,
before_limit: int,
@@ -747,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
+ # This cannot happen as `allow_none=False`.
+ assert results is not None
+
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@@ -856,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="update_federation_out_pos",
)
- def _reset_federation_positions_txn(self, txn) -> None:
+ def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -895,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type
"""
txn.execute(sql)
- min_positions = dict(txn) # Map from type -> min position
+ min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
@@ -922,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _paginate_room_events_txn(
self,
- txn,
+ txn: LoggingTransaction,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 0c34bbf21a..96ffe26cc9 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- tags_by_room = {}
+ tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
@@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
- ) -> Dict[str, List[str]]:
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 9eef8e57c5..b89668d561 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
class UIAuthStore(UIAuthWorkerStore):
- def delete_old_ui_auth_sessions(self, expiration_time: int):
+ async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f2e93614..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
import logging
import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
- def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+ async def update_profile_in_user_dir(
+ self, user_id: str, display_name: str, avatar_url: str
+ ) -> None:
"""
Update or add a user's profile in the user directory.
"""
+ # If the display name or avatar URL are unexpected types, overwrite them.
+ if not isinstance(display_name, str):
+ display_name = None
+ if not isinstance(avatar_url, str):
+ avatar_url = None
def _update_profile_in_user_dir_txn(txn):
new_entry = self.db_pool.simple_upsert_txn(
@@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def add_users_who_share_private_room(self, room_id, user_id_tuples):
+ async def add_users_who_share_private_room(
+ self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ room_id
+ user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
@@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def add_users_in_public_rooms(self, room_id, user_ids):
+ async def add_users_in_public_rooms(
+ self, room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_ids (list[str])
+ room_id
+ user_ids
"""
def _add_users_in_public_rooms_txn(txn):
@@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
- def delete_all_from_user_dir(self):
+ async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
- def remove_from_user_dir(self, user_id):
+ async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn
)
@@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
- def remove_user_who_share_room(self, user_id, room_id):
+ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
- user_id (str)
- room_id (str)
+ user_id
+ room_id
"""
def _remove_user_who_share_room_txn(txn):
@@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@@ -664,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
+ @cached()
+ async def get_shared_rooms_for_users(
+ self, user_id: str, other_user_id: str
+ ) -> Set[str]:
+ """
+ Returns the rooms that a local user shares with another local or remote user.
+
+ Args:
+ user_id: The MXID of a local user
+ other_user_id: The MXID of the other user
+
+ Returns:
+ A set of room ID's that the users share.
+ """
+
+ def _get_shared_rooms_for_users_txn(txn):
+ txn.execute(
+ """
+ SELECT p1.room_id
+ FROM users_in_public_rooms as p1
+ INNER JOIN users_in_public_rooms as p2
+ ON p1.room_id = p2.room_id
+ AND p1.user_id = ?
+ AND p2.user_id = ?
+ UNION
+ SELECT room_id
+ FROM users_who_share_private_rooms
+ WHERE
+ user_id = ?
+ AND other_user_id = ?
+ """,
+ (user_id, other_user_id, user_id, other_user_id),
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows
+
+ rows = await self.db_pool.runInteraction(
+ "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ )
+
+ return {row["room_id"] for row in rows}
+
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index e3547e53b3..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id: str) -> None:
+ async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_erased", f)
+ await self.db_pool.runInteraction("mark_user_erased", f)
- def mark_user_not_erased(self, user_id: str) -> None:
+ async def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_not_erased", f)
+ await self.db_pool.runInteraction("mark_user_not_erased", f)
|