diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d2ba4bd2fc..ae4bf1a54f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None:
self.conn.commit()
- def rollback(self, *args, **kwargs) -> None:
- self.conn.rollback(*args, **kwargs)
+ def rollback(self) -> None:
+ self.conn.rollback()
def __enter__(self) -> "Connection":
self.conn.__enter__()
@@ -244,12 +244,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
+ def fetchone(self) -> Optional[Tuple]:
+ return self.txn.fetchone()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+ return self.txn.fetchmany(size=size)
+
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
- def fetchone(self) -> Tuple:
- return self.txn.fetchone()
-
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@@ -754,6 +757,7 @@ class DatabasePool:
Returns:
A list of dicts where the key is the column header.
"""
+ assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ ) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database.
Args:
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
+ async def count_daily_e2ee_messages(self):
+ """
+ Returns an estimate of the number of messages sent in the last day.
+
+ If it has been significantly less or more than one day since the last
+ call to this function, it will return None.
+ """
+
+ def _count_messages(txn):
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+ async def count_daily_sent_e2ee_messages(self):
+ def _count_messages(txn):
+ # This is good enough as if you have silly characters in your own
+ # hostname then thats your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_sent_e2ee_messages", _count_messages
+ )
+
+ async def count_daily_active_e2ee_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_active_e2ee_rooms", _count
+ )
+
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0618b4387a..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -472,6 +472,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+ async def record_user_external_id(
+ self, auth_provider: str, external_id: str, user_id: str
+ ) -> None:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ await self.db_pool.simple_insert(
+ table="user_external_ids",
+ values={
+ "auth_provider": auth_provider,
+ "external_id": external_id,
+ "user_id": user_id,
+ },
+ desc="record_user_external_id",
+ )
+
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@@ -1400,26 +1420,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- async def record_user_external_id(
- self, auth_provider: str, external_id: str, user_id: str
- ) -> None:
- """Record a mapping from an external user id to a mxid
-
- Args:
- auth_provider: identifier for the remote auth provider
- external_id: id on that system
- user_id: complete mxid that it is mapped to
- """
- await self.db_pool.simple_insert(
- table="user_external_ids",
- values={
- "auth_provider": auth_provider,
- "external_id": external_id,
- "user_id": user_id,
- },
- desc="record_user_external_id",
- )
-
async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 035f9ea6e9..d15ccfacde 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import platform
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
@@ -28,11 +27,8 @@ def create_engine(database_config) -> BaseDatabaseEngine:
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
- # pypy requires psycopg2cffi rather than psycopg2
- if platform.python_implementation() == "PyPy":
- import psycopg2cffi as psycopg2 # type: ignore
- else:
- import psycopg2 # type: ignore
+ # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
+ import psycopg2 # type: ignore
return PostgresEngine(psycopg2, database_config)
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5db0f0b520..b3d1834efb 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import platform
import struct
import threading
import typing
@@ -30,6 +31,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
+ if platform.python_implementation() == "PyPy":
+ # pypy's sqlite3 module doesn't handle bytearrays, convert them
+ # back to bytes.
+ database_module.register_adapter(bytearray, lambda array: bytes(array))
+
# The current max state_group, or None if we haven't looked
# in the DB yet.
self._current_state_group_id = None
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19bae..28bb2eb662 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -619,9 +619,9 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
- current_version = int(row[0]) if row else None
- if current_version:
+ if row is not None:
+ current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
class Cursor(Protocol):
- def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...
- def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...
- def fetchall(self) -> List[Tuple]:
+ def fetchone(self) -> Optional[Tuple]:
+ ...
+
+ def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...
- def fetchone(self) -> Tuple:
+ def fetchall(self) -> List[Tuple]:
...
@property
- def description(self) -> Any:
- return None
+ def description(
+ self,
+ ) -> Optional[
+ Sequence[
+ # Note that this is an approximate typing based on sqlite3 and other
+ # drivers, and may not be entirely accurate.
+ Tuple[
+ str,
+ Optional[Any],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ ]
+ ]
+ ]:
+ ...
@property
def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
def commit(self) -> None:
...
- def rollback(self, *args, **kwargs) -> None:
+ def rollback(self) -> None:
...
def __enter__(self) -> "Connection":
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 0ec4dc2918..e2b316a218 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
- return txn.fetchone()[0]
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ return fetch_res[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
@@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
- last_value, is_called = txn.fetchone()
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ last_value, is_called = fetch_res
# If we have an associated stream check the stream_positions table.
max_in_stream_positions = None
|