diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 33c42cf95a..4d4643619f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -670,8 +670,8 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks = [] # type: List[_CallbackListEntry]
- exception_callbacks = [] # type: List[_CallbackListEntry]
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -832,31 +832,16 @@ class DatabasePool:
self,
table: str,
values: Dict[str, Any],
- or_ignore: bool = False,
desc: str = "simple_insert",
- ) -> bool:
+ ) -> None:
"""Executes an INSERT query on the named table.
Args:
table: string giving the table name
values: dict of new column names and values for them
- or_ignore: bool stating whether an exception should be raised
- when a conflicting row already exists. If True, False will be
- returned by the function instead
desc: description of the transaction, for logging and metrics
-
- Returns:
- Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
- try:
- await self.runInteraction(desc, self.simple_insert_txn, table, values)
- except self.engine.module.IntegrityError:
- # We have to do or_ignore flag at this layer, since we can't reuse
- # a cursor after we receive an error from the db.
- if not or_ignore:
- raise
- return False
- return True
+ await self.runInteraction(desc, self.simple_insert_txn, table, values)
@staticmethod
def simple_insert_txn(
@@ -907,7 +892,7 @@ class DatabasePool:
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(
- *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
)
for k in keys:
@@ -930,7 +915,7 @@ class DatabasePool:
insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert",
lock: bool = True,
- ) -> Optional[bool]:
+ ) -> bool:
"""
`lock` should generally be set to True (the default), but can be set
@@ -951,8 +936,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
lock: True to lock the table when doing the upsert.
Returns:
- Native upserts always return None. Emulated upserts return True if a
- new entry was created, False if an existing one was updated.
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
insertion_values = insertion_values or {}
@@ -995,7 +980,7 @@ class DatabasePool:
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True,
- ) -> Optional[bool]:
+ ) -> bool:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
@@ -1008,16 +993,15 @@ class DatabasePool:
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
- Native upserts always return None. Emulated upserts return True if a
- new entry was created, False if an existing one was updated.
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
insertion_values = insertion_values or {}
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- self.simple_upsert_txn_native_upsert(
+ return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
- return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -1045,8 +1029,8 @@ class DatabasePool:
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
- Returns True if a new entry was created, False if an existing
- one was updated.
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
insertion_values = insertion_values or {}
@@ -1086,11 +1070,10 @@ class DatabasePool:
txn.execute(sql, sqlargs)
if txn.rowcount > 0:
- # successfully updated at least one row.
- return False
+ return True
# We didn't find any existing rows, so insert a new one
- allvalues = {} # type: Dict[str, Any]
+ allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@@ -1111,17 +1094,21 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
- ) -> None:
+ ) -> bool:
"""
- Use the native UPSERT functionality in recent PostgreSQL versions.
+ Use the native UPSERT functionality in PostgreSQL.
Args:
table: The table to upsert into
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
+
+ Returns:
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
- allvalues = {} # type: Dict[str, Any]
+ allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
allvalues.update(insertion_values or {})
@@ -1140,6 +1127,8 @@ class DatabasePool:
)
txn.execute(sql, list(allvalues.values()))
+ return bool(txn.rowcount)
+
async def simple_upsert_many(
self,
table: str,
@@ -1257,7 +1246,7 @@ class DatabasePool:
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
- allnames = [] # type: List[str]
+ allnames: List[str] = []
allnames.extend(key_names)
allnames.extend(value_names)
@@ -1566,7 +1555,7 @@ class DatabasePool:
"""
keyvalues = keyvalues or {}
- results = [] # type: List[Dict[str, Any]]
+ results: List[Dict[str, Any]] = []
if not iterable:
return results
@@ -1978,7 +1967,7 @@ class DatabasePool:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
- arg_list = [] # type: List[Any]
+ arg_list: List[Any] = []
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
|