summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-11-08 07:45:34 -0500
committerPatrick Cloke <patrickc@matrix.org>2023-11-08 07:45:34 -0500
commitb77c9c3f735cfaf7ecab624a72d02dfe302dbe90 (patch)
tree03ec9900056e384a2a4d0ea0a9ae11ef533b0059
parentMerge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parentAvoid updating the same rows multiple times with simple_update_many_txn. (#16... (diff)
downloadsynapse-b77c9c3f735cfaf7ecab624a72d02dfe302dbe90.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
-rw-r--r--changelog.d/16532.misc1
-rw-r--r--changelog.d/16583.misc1
-rw-r--r--changelog.d/16590.misc1
-rw-r--r--changelog.d/16596.misc1
-rw-r--r--changelog.d/16605.misc1
-rw-r--r--changelog.d/16609.bugfix1
-rw-r--r--mypy.ini4
-rw-r--r--poetry.lock58
-rw-r--r--pyproject.toml4
-rw-r--r--synapse/metrics/_reactor_metrics.py130
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py56
-rw-r--r--synapse/storage/database.py50
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/events.py12
-rw-r--r--synapse/storage/databases/main/push_rule.py50
-rw-r--r--synapse/storage/databases/main/room.py2
-rw-r--r--synapse/storage/databases/main/search.py4
-rw-r--r--synapse/util/async_helpers.py14
-rw-r--r--tests/storage/test_base.py631
19 files changed, 879 insertions, 144 deletions
diff --git a/changelog.d/16532.misc b/changelog.d/16532.misc
new file mode 100644

index 0000000000..437e00210b --- /dev/null +++ b/changelog.d/16532.misc
@@ -0,0 +1 @@ +Support reactor tick timings on more types of event loops. diff --git a/changelog.d/16583.misc b/changelog.d/16583.misc new file mode 100644
index 0000000000..df5b27b112 --- /dev/null +++ b/changelog.d/16583.misc
@@ -0,0 +1 @@ +Avoid executing no-op queries. diff --git a/changelog.d/16590.misc b/changelog.d/16590.misc new file mode 100644
index 0000000000..6db04b0c98 --- /dev/null +++ b/changelog.d/16590.misc
@@ -0,0 +1 @@ +Run push rule evaluator setup in parallel. diff --git a/changelog.d/16596.misc b/changelog.d/16596.misc new file mode 100644
index 0000000000..fa457b12e5 --- /dev/null +++ b/changelog.d/16596.misc
@@ -0,0 +1 @@ +Improve tests of the SQL generator. diff --git a/changelog.d/16605.misc b/changelog.d/16605.misc new file mode 100644
index 0000000000..2db7da5692 --- /dev/null +++ b/changelog.d/16605.misc
@@ -0,0 +1 @@ +Bump setuptools-rust from 1.8.0 to 1.8.1. diff --git a/changelog.d/16609.bugfix b/changelog.d/16609.bugfix new file mode 100644
index 0000000000..a52d395cd3 --- /dev/null +++ b/changelog.d/16609.bugfix
@@ -0,0 +1 @@ +Fix a long-standing bug where some queries updated the same row twice. Introduced in Synapse 1.57.0. diff --git a/mypy.ini b/mypy.ini
index fdfe9432fc..1a2b9ea410 100644 --- a/mypy.ini +++ b/mypy.ini
@@ -37,8 +37,8 @@ files = build_rust.py [mypy-synapse.metrics._reactor_metrics] -# This module imports select.epoll. That exists on Linux, but doesn't on macOS. -# See https://github.com/matrix-org/synapse/pull/11771. +# This module pokes at the internals of OS-specific classes, to appease mypy +# on different systems we add additional ignores. warn_unused_ignores = False [mypy-synapse.util.caches.treecache] diff --git a/poetry.lock b/poetry.lock
index 334005241e..41556635d3 100644 --- a/poetry.lock +++ b/poetry.lock
@@ -2439,28 +2439,28 @@ files = [ [[package]] name = "ruff" -version = "0.0.292" -description = "An extremely fast Python linter, written in Rust." +version = "0.1.4" +description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.292-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:02f29db018c9d474270c704e6c6b13b18ed0ecac82761e4fcf0faa3728430c96"}, - {file = "ruff-0.0.292-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:69654e564342f507edfa09ee6897883ca76e331d4bbc3676d8a8403838e9fade"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c3c91859a9b845c33778f11902e7b26440d64b9d5110edd4e4fa1726c41e0a4"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f4476f1243af2d8c29da5f235c13dca52177117935e1f9393f9d90f9833f69e4"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be8eb50eaf8648070b8e58ece8e69c9322d34afe367eec4210fdee9a555e4ca7"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9889bac18a0c07018aac75ef6c1e6511d8411724d67cb879103b01758e110a81"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6bdfabd4334684a4418b99b3118793f2c13bb67bf1540a769d7816410402a205"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7c77c53bfcd75dbcd4d1f42d6cabf2485d2e1ee0678da850f08e1ab13081a8"}, - {file = "ruff-0.0.292-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e087b24d0d849c5c81516ec740bf4fd48bf363cfb104545464e0fca749b6af9"}, - {file = "ruff-0.0.292-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f160b5ec26be32362d0774964e218f3fcf0a7da299f7e220ef45ae9e3e67101a"}, - {file = "ruff-0.0.292-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ac153eee6dd4444501c4bb92bff866491d4bfb01ce26dd2fff7ca472c8df9ad0"}, - {file = "ruff-0.0.292-py3-none-musllinux_1_2_i686.whl", hash = "sha256:87616771e72820800b8faea82edd858324b29bb99a920d6aa3d3949dd3f88fb0"}, - {file = "ruff-0.0.292-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b76deb3bdbea2ef97db286cf953488745dd6424c122d275f05836c53f62d4016"}, - {file = "ruff-0.0.292-py3-none-win32.whl", hash = "sha256:e854b05408f7a8033a027e4b1c7f9889563dd2aca545d13d06711e5c39c3d003"}, - {file = "ruff-0.0.292-py3-none-win_amd64.whl", hash = "sha256:f27282bedfd04d4c3492e5c3398360c9d86a295be00eccc63914438b4ac8a83c"}, - {file = "ruff-0.0.292-py3-none-win_arm64.whl", hash = "sha256:7f67a69c8f12fbc8daf6ae6d36705037bde315abf8b82b6e1f4c9e74eb750f68"}, - {file = "ruff-0.0.292.tar.gz", hash = "sha256:1093449e37dd1e9b813798f6ad70932b57cf614e5c2b5c51005bf67d55db33ac"}, + {file = "ruff-0.1.4-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:864958706b669cce31d629902175138ad8a069d99ca53514611521f532d91495"}, + {file = "ruff-0.1.4-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:9fdd61883bb34317c788af87f4cd75dfee3a73f5ded714b77ba928e418d6e39e"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4eaca8c9cc39aa7f0f0d7b8fe24ecb51232d1bb620fc4441a61161be4a17539"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9a1301dc43cbf633fb603242bccd0aaa34834750a14a4c1817e2e5c8d60de17"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78e8db8ab6f100f02e28b3d713270c857d370b8d61871d5c7d1702ae411df683"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:80fea754eaae06335784b8ea053d6eb8e9aac75359ebddd6fee0858e87c8d510"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6bc02a480d4bfffd163a723698da15d1a9aec2fced4c06f2a753f87f4ce6969c"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862811b403063765b03e716dac0fda8fdbe78b675cd947ed5873506448acea4"}, + {file = "ruff-0.1.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58826efb8b3efbb59bb306f4b19640b7e366967a31c049d49311d9eb3a4c60cb"}, + {file = "ruff-0.1.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:fdfd453fc91d9d86d6aaa33b1bafa69d114cf7421057868f0b79104079d3e66e"}, + {file = "ruff-0.1.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e8791482d508bd0b36c76481ad3117987301b86072158bdb69d796503e1c84a8"}, + {file = "ruff-0.1.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01206e361021426e3c1b7fba06ddcb20dbc5037d64f6841e5f2b21084dc51800"}, + {file = "ruff-0.1.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:645591a613a42cb7e5c2b667cbefd3877b21e0252b59272ba7212c3d35a5819f"}, + {file = "ruff-0.1.4-py3-none-win32.whl", hash = "sha256:99908ca2b3b85bffe7e1414275d004917d1e0dfc99d497ccd2ecd19ad115fd0d"}, + {file = "ruff-0.1.4-py3-none-win_amd64.whl", hash = "sha256:1dfd6bf8f6ad0a4ac99333f437e0ec168989adc5d837ecd38ddb2cc4a2e3db8a"}, + {file = "ruff-0.1.4-py3-none-win_arm64.whl", hash = "sha256:d98ae9ebf56444e18a3e3652b3383204748f73e247dea6caaf8b52d37e6b32da"}, + {file = "ruff-0.1.4.tar.gz", hash = "sha256:21520ecca4cc555162068d87c747b8f95e1e95f8ecfcbbe59e8dd00710586315"}, ] [[package]] @@ -2580,13 +2580,13 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( [[package]] name = "setuptools-rust" -version = "1.8.0" +version = "1.8.1" description = "Setuptools Rust extension plugin" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-rust-1.8.0.tar.gz", hash = "sha256:5e02b7a80058853bf64127314f6b97d0efed11e08b94c88ca639a20976f6adc4"}, - {file = "setuptools_rust-1.8.0-py3-none-any.whl", hash = "sha256:95ec67edee2ca73233c9e75250e9d23a302aa23b4c8413dfd19c14c30d08f703"}, + {file = "setuptools-rust-1.8.1.tar.gz", hash = "sha256:94b1dd5d5308b3138d5b933c3a2b55e6d6927d1a22632e509fcea9ddd0f7e486"}, + {file = "setuptools_rust-1.8.1-py3-none-any.whl", hash = "sha256:b5324493949ccd6aa0c03890c5f6b5f02de4512e3ac1697d02e9a6c02b18aa8e"}, ] [package.dependencies] @@ -3069,13 +3069,13 @@ files = [ [[package]] name = "types-jsonschema" -version = "4.19.0.3" +version = "4.19.0.4" description = "Typing stubs for jsonschema" optional = false python-versions = ">=3.8" files = [ - {file = "types-jsonschema-4.19.0.3.tar.gz", hash = "sha256:e0fc0f5d51fd0988bf193be42174a5376b0096820ff79505d9c1b66de23f0581"}, - {file = "types_jsonschema-4.19.0.3-py3-none-any.whl", hash = "sha256:5cedbb661e5ca88d95b94b79902423e3f97a389c245e5fe0ab384122f27d56b9"}, + {file = "types-jsonschema-4.19.0.4.tar.gz", hash = "sha256:994feb6632818259c4b5dbd733867824cb475029a6abc2c2b5201a2268b6e7d2"}, + {file = "types_jsonschema-4.19.0.4-py3-none-any.whl", hash = "sha256:b73c3f4ba3cd8108602d1198a438e2698d5eb6b9db206ed89a33e24729b0abe7"}, ] [package.dependencies] @@ -3141,13 +3141,13 @@ cryptography = ">=35.0.0" [[package]] name = "types-pyyaml" -version = "6.0.12.11" +version = "6.0.12.12" description = "Typing stubs for PyYAML" optional = false python-versions = "*" files = [ - {file = "types-PyYAML-6.0.12.11.tar.gz", hash = "sha256:7d340b19ca28cddfdba438ee638cd4084bde213e501a3978738543e27094775b"}, - {file = "types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"}, + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, ] [[package]] @@ -3447,4 +3447,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "a08543c65f18cc7e9dea648e89c18ab88fc1747aa2e029aa208f777fc3db06dd" +content-hash = "369455d6a67753a6bcfbad3cd86801b1dd02896d0180080e2ba9501e007353ec" diff --git a/pyproject.toml b/pyproject.toml
index 23e0004395..df132c0236 100644 --- a/pyproject.toml +++ b/pyproject.toml
@@ -321,7 +321,7 @@ all = [ # This helps prevents merge conflicts when running a batch of dependabot updates. isort = ">=5.10.1" black = ">=22.7.0" -ruff = "0.0.292" +ruff = "0.1.4" # Type checking only works with the pydantic.v1 compat module from pydantic v2 pydantic = "^2" @@ -381,7 +381,7 @@ furo = ">=2022.12.7,<2024.0.0" # system changes. # We are happy to raise these upper bounds upon request, # provided we check that it's safe to do so (i.e. that CI passes). -requires = ["poetry-core>=1.1.0,<=1.7.0", "setuptools_rust>=1.3,<=1.8.0"] +requires = ["poetry-core>=1.1.0,<=1.7.0", "setuptools_rust>=1.3,<=1.8.1"] build-backend = "poetry.core.masonry.api" diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py
index a2c6e6842d..dd486dd3e2 100644 --- a/synapse/metrics/_reactor_metrics.py +++ b/synapse/metrics/_reactor_metrics.py
@@ -12,17 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import select +import logging import time -from typing import Any, Iterable, List, Tuple +from selectors import SelectSelector, _PollLikeSelector # type: ignore[attr-defined] +from typing import Any, Callable, Iterable from prometheus_client import Histogram, Metric from prometheus_client.core import REGISTRY, GaugeMetricFamily -from twisted.internet import reactor +from twisted.internet import reactor, selectreactor +from twisted.internet.asyncioreactor import AsyncioSelectorReactor from synapse.metrics._types import Collector +try: + from selectors import KqueueSelector +except ImportError: + + class KqueueSelector: # type: ignore[no-redef] + pass + + +try: + from twisted.internet.epollreactor import EPollReactor +except ImportError: + + class EPollReactor: # type: ignore[no-redef] + pass + + +try: + from twisted.internet.pollreactor import PollReactor +except ImportError: + + class PollReactor: # type: ignore[no-redef] + pass + + +logger = logging.getLogger(__name__) + # # Twisted reactor metrics # @@ -34,52 +62,100 @@ tick_time = Histogram( ) -class EpollWrapper: - """a wrapper for an epoll object which records the time between polls""" +class CallWrapper: + """A wrapper for a callable which records the time between calls""" - def __init__(self, poller: "select.epoll"): # type: ignore[name-defined] + def __init__(self, wrapped: Callable[..., Any]): self.last_polled = time.time() - self._poller = poller + self._wrapped = wrapped - def poll(self, *args, **kwargs) -> List[Tuple[int, int]]: # type: ignore[no-untyped-def] - # record the time since poll() was last called. This gives a good proxy for + def __call__(self, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] + # record the time since this was last called. This gives a good proxy for # how long it takes to run everything in the reactor - ie, how long anything # waiting for the next tick will have to wait. tick_time.observe(time.time() - self.last_polled) - ret = self._poller.poll(*args, **kwargs) + ret = self._wrapped(*args, **kwargs) self.last_polled = time.time() return ret + +class ObjWrapper: + """A wrapper for an object which wraps a specified method in CallWrapper. + + Other methods/attributes are passed to the original object. + + This is necessary when the wrapped object does not allow the attribute to be + overwritten. + """ + + def __init__(self, wrapped: Any, method_name: str): + self._wrapped = wrapped + self._method_name = method_name + self._wrapped_method = CallWrapper(getattr(wrapped, method_name)) + def __getattr__(self, item: str) -> Any: - return getattr(self._poller, item) + if item == self._method_name: + return self._wrapped_method + + return getattr(self._wrapped, item) class ReactorLastSeenMetric(Collector): - def __init__(self, epoll_wrapper: EpollWrapper): - self._epoll_wrapper = epoll_wrapper + def __init__(self, call_wrapper: CallWrapper): + self._call_wrapper = call_wrapper def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", ) - cm.add_metric([], time.time() - self._epoll_wrapper.last_polled) + cm.add_metric([], time.time() - self._call_wrapper.last_polled) yield cm +# Twisted has already select a reasonable reactor for us, so assumptions can be +# made about the shape. +wrapper = None try: - # if the reactor has a `_poller` attribute, which is an `epoll` object - # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will - # measure the time between ticks - from select import epoll # type: ignore[attr-defined] - - poller = reactor._poller # type: ignore[attr-defined] -except (AttributeError, ImportError): - pass -else: - if isinstance(poller, epoll): - poller = EpollWrapper(poller) - reactor._poller = poller # type: ignore[attr-defined] - REGISTRY.register(ReactorLastSeenMetric(poller)) + if isinstance(reactor, (PollReactor, EPollReactor)): + reactor._poller = ObjWrapper(reactor._poller, "poll") # type: ignore[attr-defined] + wrapper = reactor._poller._wrapped_method # type: ignore[attr-defined] + + elif isinstance(reactor, selectreactor.SelectReactor): + # Twisted uses a module-level _select function. + wrapper = selectreactor._select = CallWrapper(selectreactor._select) + + elif isinstance(reactor, AsyncioSelectorReactor): + # For asyncio look at the underlying asyncio event loop. + asyncio_loop = reactor._asyncioEventloop # A sub-class of BaseEventLoop, + + # A sub-class of BaseSelector. + selector = asyncio_loop._selector # type: ignore[attr-defined] + + if isinstance(selector, SelectSelector): + wrapper = selector._select = CallWrapper(selector._select) # type: ignore[attr-defined] + + # poll, epoll, and /dev/poll. + elif isinstance(selector, _PollLikeSelector): + selector._selector = ObjWrapper(selector._selector, "poll") # type: ignore[attr-defined] + wrapper = selector._selector._wrapped_method # type: ignore[attr-defined] + + elif isinstance(selector, KqueueSelector): + selector._selector = ObjWrapper(selector._selector, "control") # type: ignore[attr-defined] + wrapper = selector._selector._wrapped_method # type: ignore[attr-defined] + + else: + # E.g. this does not support the (Windows-only) ProactorEventLoop. + logger.warning( + "Skipping configuring ReactorLastSeenMetric: unexpected asyncio loop selector: %r via %r", + selector, + asyncio_loop, + ) +except Exception as e: + logger.warning("Configuring ReactorLastSeenMetric failed: %r", e) + + +if wrapper: + REGISTRY.register(ReactorLastSeenMetric(wrapper)) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 14784312dc..5934b1ef34 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -25,10 +25,13 @@ from typing import ( Sequence, Tuple, Union, + cast, ) from prometheus_client import Counter +from twisted.internet.defer import Deferred + from synapse.api.constants import ( MAIN_TIMELINE, EventContentFields, @@ -40,11 +43,15 @@ from synapse.api.room_versions import PushRuleRoomFlag from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership +from synapse.storage.roommember import ProfileInfo from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.types import JsonValue from synapse.types.state import StateFilter +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import gather_results from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -342,15 +349,41 @@ class BulkPushRuleEvaluator: rules_by_user = await self._get_rules_for_event(event) actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {} - room_member_count = await self.store.get_number_joined_users_in_room( - event.room_id - ) - + # Gather a bunch of info in parallel. + # + # This has a lot of ignored types and casting due to the use of @cached + # decorated functions passed into run_in_background. + # + # See https://github.com/matrix-org/synapse/issues/16606 ( - power_levels, - sender_power_level, - ) = await self._get_power_levels_and_sender_level( - event, context, event_id_to_event + room_member_count, + (power_levels, sender_power_level), + related_events, + profiles, + ) = await make_deferred_yieldable( + cast( + "Deferred[Tuple[int, Tuple[dict, Optional[int]], Dict[str, Dict[str, JsonValue]], Mapping[str, ProfileInfo]]]", + gather_results( + ( + run_in_background( # type: ignore[call-arg] + self.store.get_number_joined_users_in_room, event.room_id # type: ignore[arg-type] + ), + run_in_background( + self._get_power_levels_and_sender_level, + event, + context, + event_id_to_event, + ), + run_in_background(self._related_events, event), + run_in_background( # type: ignore[call-arg] + self.store.get_subset_users_in_room_with_profiles, + event.room_id, # type: ignore[arg-type] + rules_by_user.keys(), # type: ignore[arg-type] + ), + ), + consumeErrors=True, + ).addErrback(unwrapFirstError), + ) ) # Find the event's thread ID. @@ -366,8 +399,6 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) - related_events = await self._related_events(event) - # It's possible that old room versions have non-integer power levels (floats or # strings; even the occasional `null`). For old rooms, we interpret these as if # they were integers. Do this here for the `@room` power level threshold. @@ -400,11 +431,6 @@ class BulkPushRuleEvaluator: self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) - users = rules_by_user.keys() - profiles = await self.store.get_subset_users_in_room_with_profiles( - event.room_id, users - ) - for uid, rules in rules_by_user.items(): if event.sender == uid: continue diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6d54bb0eb2..f50a4ce2fc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -1117,7 +1117,7 @@ class DatabasePool: txn: LoggingTransaction, table: str, keys: Collection[str], - values: Iterable[Iterable[Any]], + values: Collection[Iterable[Any]], ) -> None: """Executes an INSERT query on the named table. @@ -1130,6 +1130,9 @@ class DatabasePool: keys: list of column names values: for each row, a list of values in the same order as `keys` """ + # If there's nothing to insert, then skip executing the query. + if not values: + return if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, @@ -1401,12 +1404,12 @@ class DatabasePool: allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % ( + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), - f"WHERE {where_clause}" if where_clause else "", + f"WHERE {where_clause} " if where_clause else "", latter, ) txn.execute(sql, list(allvalues.values())) @@ -1455,7 +1458,7 @@ class DatabasePool: key_names: Collection[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[Any]], + value_values: Collection[Iterable[Any]], ) -> None: """ Upsert, many times. @@ -1468,6 +1471,19 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ + # If there's nothing to upsert, then skip executing the query. + if not key_values: + return + + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + elif len(value_values) != len(key_values): + raise ValueError( + f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." + ) + if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values @@ -1502,10 +1518,6 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] # Lock the table just once, to prevent it being done once per row. # Note that, according to Postgres' documentation, once obtained, @@ -1543,10 +1555,7 @@ class DatabasePool: allnames.extend(value_names) if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. latter = "NOTHING" - value_values = [() for x in range(len(key_values))] else: latter = "UPDATE SET " + ", ".join( k + "=EXCLUDED." + k for k in value_names @@ -1910,6 +1919,7 @@ class DatabasePool: Returns: The results as a list of tuples. """ + # If there's nothing to select, then skip executing the query. if not iterable: return [] @@ -2044,13 +2054,13 @@ class DatabasePool: raise ValueError( f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." ) + # If there is nothing to update, then skip executing the query. + if not key_values: + return # List of tuples of (value values, then key values) # (This matches the order needed for the query) - args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)] - - for ks, vs in zip(key_values, value_values): - args.append(tuple(vs) + tuple(ks)) + args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)] # 'col1 = ?, col2 = ?, ...' set_clause = ", ".join(f"{n} = ?" for n in value_names) @@ -2062,9 +2072,7 @@ class DatabasePool: where_clause = "" # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ? - sql = f""" - UPDATE {table} SET {set_clause} {where_clause} - """ + sql = f"UPDATE {table} SET {set_clause} {where_clause}" txn.execute_batch(sql, args) @@ -2280,11 +2288,10 @@ class DatabasePool: Returns: Number rows deleted """ + # If there's nothing to delete, then skip executing the query. if not values: return 0 - sql = "DELETE FROM %s" % table - clause, values = make_in_list_sql_clause(txn.database_engine, column, values) clauses = [clause] @@ -2292,8 +2299,7 @@ class DatabasePool: clauses.append("%s = ?" % (key,)) values.append(value) - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses)) txn.execute(sql, values) return txn.rowcount diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 303ef6ea27..34d6c52e39 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -705,7 +705,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): key_names=("destination", "user_id"), key_values=[(destination, user_id) for user_id, _ in rows], value_names=("stream_id",), - value_values=((stream_id,) for _, stream_id in rows), + value_values=[(stream_id,) for _, stream_id in rows], ) # Delete all sent outbound pokes diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 647ba182f6..7c34bde3e5 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -1476,7 +1476,7 @@ class PersistEventsStore: txn, table="event_json", keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), - values=( + values=[ ( event.event_id, event.room_id, @@ -1485,7 +1485,7 @@ class PersistEventsStore: event.format_version, ) for event, _ in events_and_contexts - ), + ], ) self.db_pool.simple_insert_many_txn( @@ -1508,7 +1508,7 @@ class PersistEventsStore: "state_key", "rejection_reason", ), - values=( + values=[ ( self._instance_name, event.internal_metadata.stream_ordering, @@ -1527,7 +1527,7 @@ class PersistEventsStore: context.rejected, ) for event, context in events_and_contexts - ), + ], ) # If we're persisting an unredacted event we go and ensure @@ -1550,11 +1550,11 @@ class PersistEventsStore: txn, table="state_events", keys=("event_id", "room_id", "type", "state_key"), - values=( + values=[ (event.event_id, event.room_id, event.type, event.state_key) for event, _ in events_and_contexts if event.is_state() - ), + ], ) def _store_rejected_events_txn( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 22025eca56..37135d431d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import ( cast, ) +from twisted.internet import defer + from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import ( ) from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import json_encoder, unwrapFirstError +from synapse.util.async_helpers import gather_results from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -249,23 +253,33 @@ class PushRulesWorkerStore( user_id: [] for user_id in user_ids } - rows = cast( - List[Tuple[str, str, int, int, str, str]], - await self.db_pool.simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", + # gatherResults loses all type information. + rows, enabled_map_by_user = await make_deferred_yieldable( + gather_results( + ( + cast( + "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]", + run_in_background( + self.db_pool.simple_select_many_batch, + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="bulk_get_push_rules", + batch_size=1000, + ), + ), + run_in_background(self.bulk_get_push_rules_enabled, user_ids), ), - desc="bulk_get_push_rules", - batch_size=1000, - ), + consumeErrors=True, + ).addErrback(unwrapFirstError) ) # Sort by highest priority_class, then highest priority. @@ -276,8 +290,6 @@ class PushRulesWorkerStore( (rule_id, priority_class, conditions, actions) ) - enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) - results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6d4b9891e7..afb880532e 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -2268,7 +2268,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): txn, table="partial_state_rooms_servers", keys=("room_id", "server_name"), - values=((room_id, s) for s in servers), + values=[(room_id, s) for s in servers], ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index dbde9130c6..f4bef4c99b 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -106,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore): txn, table="event_search", keys=("event_id", "room_id", "key", "value"), - values=( + values=[ ( entry.event_id, entry.room_id, @@ -114,7 +114,7 @@ class SearchWorkerStore(SQLBaseStore): _clean_value_for_search(entry.value), ) for entry in entries - ), + ], ) else: diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 0cbeb0c365..8a55e4e41d 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -345,6 +345,7 @@ async def yieldable_gather_results_delaying_cancellation( T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") +T4 = TypeVar("T4") @overload @@ -380,6 +381,19 @@ def gather_results( ... +@overload +def gather_results( + deferredList: Tuple[ + "defer.Deferred[T1]", + "defer.Deferred[T2]", + "defer.Deferred[T3]", + "defer.Deferred[T4]", + ], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]": + ... + + def gather_results( # type: ignore[misc] deferredList: Tuple["defer.Deferred[T1]", ...], consumeErrors: bool = False, diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index e4a52c301e..f34b6b2dcf 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -14,7 +14,7 @@ from collections import OrderedDict from typing import Generator -from unittest.mock import Mock +from unittest.mock import Mock, call, patch from twisted.internet import defer @@ -24,43 +24,90 @@ from synapse.storage.engines import create_engine from tests import unittest from tests.server import TestHomeServer -from tests.utils import default_config +from tests.utils import USE_POSTGRES_FOR_TESTS, default_config class SQLBaseStoreTestCase(unittest.TestCase): """Test the "simple" SQL generating methods in SQLBaseStore.""" def setUp(self) -> None: - self.db_pool = Mock(spec=["runInteraction"]) + # This is the Twisted connection pool. + conn_pool = Mock(spec=["runInteraction", "runWithConnection"]) self.mock_txn = Mock() - self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) + if USE_POSTGRES_FOR_TESTS: + # To avoid testing psycopg2 itself, patch execute_batch/execute_values + # to assert how it is called. + from psycopg2 import extras + + self.mock_execute_batch = Mock() + self.execute_batch_patcher = patch.object( + extras, "execute_batch", new=self.mock_execute_batch + ) + self.execute_batch_patcher.start() + self.mock_execute_values = Mock() + self.execute_values_patcher = patch.object( + extras, "execute_values", new=self.mock_execute_values + ) + self.execute_values_patcher.start() + + self.mock_conn = Mock( + spec_set=[ + "cursor", + "rollback", + "commit", + "closed", + "reconnect", + "set_session", + "encoding", + ] + ) + self.mock_conn.encoding = "UNICODE" + else: + self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) self.mock_conn.cursor.return_value = self.mock_txn + self.mock_txn.connection = self.mock_conn self.mock_conn.rollback.return_value = None # Our fake runInteraction just runs synchronously inline def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_txn, *args, **kwargs)) - self.db_pool.runInteraction = runInteraction + conn_pool.runInteraction = runInteraction def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_conn, *args, **kwargs)) - self.db_pool.runWithConnection = runWithConnection + conn_pool.runWithConnection = runWithConnection config = default_config(name="test", parse=True) hs = TestHomeServer("test", config=config) - sqlite_config = {"name": "sqlite3"} - engine = create_engine(sqlite_config) + if USE_POSTGRES_FOR_TESTS: + db_config = {"name": "psycopg2", "args": {}} + else: + db_config = {"name": "sqlite3"} + engine = create_engine(db_config) + fake_engine = Mock(wraps=engine) fake_engine.in_transaction.return_value = False + fake_engine.module.OperationalError = engine.module.OperationalError + fake_engine.module.DatabaseError = engine.module.DatabaseError + fake_engine.module.IntegrityError = engine.module.IntegrityError + # Don't convert param style to make assertions easier. + fake_engine.convert_param_style = lambda sql: sql + # To fix isinstance(...) checks. + fake_engine.__class__ = engine.__class__ # type: ignore[assignment] - db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) - db._db_pool = self.db_pool + db = DatabasePool(Mock(), Mock(config=db_config), fake_engine) + db._db_pool = conn_pool self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] + def tearDown(self) -> None: + if USE_POSTGRES_FOR_TESTS: + self.execute_batch_patcher.stop() + self.execute_values_patcher.stop() + @defer.inlineCallbacks def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -71,7 +118,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "INSERT INTO tablename (columname) VALUES(?)", ("Value",) ) @@ -87,11 +134,66 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3) ) @defer.inlineCallbacks + def test_insert_many(self) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert_many( + table="tablename", + keys=( + "col1", + "col2", + ), + values=[ + ( + "val1", + "val2", + ), + ("val3", "val4"), + ], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (col1, col2) VALUES ?", + [("val1", "val2"), ("val3", "val4")], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (col1, col2) VALUES(?, ?)", + [("val1", "val2"), ("val3", "val4")], + ) + + @defer.inlineCallbacks + def test_insert_many_no_iterable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert_many( + table="tablename", + keys=( + "col1", + "col2", + ), + values=[], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_not_called() + else: + self.mock_txn.executemany.assert_not_called() + + @defer.inlineCallbacks def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) @@ -103,7 +205,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) self.assertEqual("Value", value) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -121,7 +223,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -156,11 +258,59 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) self.assertEqual([(1,), (2,), (3,)], ret) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) @defer.inlineCallbacks + def test_select_many_batch( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 3 + self.mock_txn.fetchall.side_effect = [[(1,), (2,)], [(3,)]] + + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_many_batch( + table="tablename", + column="col1", + iterable=("val1", "val2", "val3"), + retcols=("col2",), + keyvalues={"col3": "val4"}, + batch_size=2, + ) + ) + + self.mock_txn.execute.assert_has_calls( + [ + call( + "SELECT col2 FROM tablename WHERE col1 = ANY(?) AND col3 = ?", + [["val1", "val2"], "val4"], + ), + call( + "SELECT col2 FROM tablename WHERE col1 = ANY(?) AND col3 = ?", + [["val3"], "val4"], + ), + ], + ) + self.assertEqual([(1,), (2,), (3,)], ret) + + def test_select_many_no_iterable(self) -> None: + self.mock_txn.rowcount = 3 + self.mock_txn.fetchall.side_effect = [(1,), (2,)] + + ret = self.datastore.db_pool.simple_select_many_txn( + self.mock_txn, + table="tablename", + column="col1", + iterable=(), + retcols=("col2",), + keyvalues={"col3": "val4"}, + ) + + self.mock_txn.execute.assert_not_called() + self.assertEqual([], ret) + + @defer.inlineCallbacks def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -172,7 +322,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "UPDATE tablename SET columnname = ? WHERE keycol = ?", ["New Value", "TheKey"], ) @@ -191,12 +341,70 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?", [3, 4, 1, 2], ) @defer.inlineCallbacks + def test_update_many(self) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_many( + table="tablename", + key_names=("col1", "col2"), + key_values=[("val1", "val2")], + value_names=("col3",), + value_values=[("val3",)], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_batch.assert_called_once_with( + self.mock_txn, + "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", + [("val3", "val1", "val2")], + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", + [("val3", "val1", "val2")], + ) + + # key_values and value_values must be the same length. + with self.assertRaises(ValueError): + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_many( + table="tablename", + key_names=("col1", "col2"), + key_values=[("val1", "val2")], + value_names=("col3",), + value_values=[], + desc="", + ) + ) + + @defer.inlineCallbacks + def test_update_many_no_iterable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_many( + table="tablename", + key_names=("col1", "col2"), + key_values=[], + value_names=("col3",), + value_values=[], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_batch.assert_not_called() + else: + self.mock_txn.executemany.assert_not_called() + + @defer.inlineCallbacks def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -206,6 +414,393 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "DELETE FROM tablename WHERE keycol = ?", ["Go away"] ) + + @defer.inlineCallbacks + def test_delete_many(self) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 2 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_many( + table="tablename", + column="col1", + iterable=("val1", "val2"), + keyvalues={"col2": "val3"}, + desc="", + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "DELETE FROM tablename WHERE col1 = ANY(?) AND col2 = ?", + [["val1", "val2"], "val3"], + ) + self.assertEqual(result, 2) + + @defer.inlineCallbacks + def test_delete_many_no_iterable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_many( + table="tablename", + column="col1", + iterable=(), + keyvalues={"col2": "val3"}, + desc="", + ) + ) + + self.mock_txn.execute.assert_not_called() + self.assertEqual(result, 0) + + @defer.inlineCallbacks + def test_delete_many_no_keyvalues( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 2 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_many( + table="tablename", + column="col1", + iterable=("val1", "val2"), + keyvalues={}, + desc="", + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "DELETE FROM tablename WHERE col1 = ANY(?)", [["val1", "val2"]] + ) + self.assertEqual(result, 2) + + @defer.inlineCallbacks + def test_upsert(self) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname, othercol) VALUES (?, ?) ON CONFLICT (columnname) DO UPDATE SET othercol=EXCLUDED.othercol", + ["oldvalue", "newvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "value"}, + values={}, + insertion_values={"columnname": "value"}, + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname) VALUES (?) ON CONFLICT (columnname) DO NOTHING", + ["value"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_with_insertion( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + insertion_values={"thirdcol": "insertionval"}, + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname, thirdcol, othercol) VALUES (?, ?, ?) ON CONFLICT (columnname) DO UPDATE SET othercol=EXCLUDED.othercol", + ["oldvalue", "insertionval", "newvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_with_where( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + where_clause="thirdcol IS NULL", + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname, othercol) VALUES (?, ?) ON CONFLICT (columnname) WHERE thirdcol IS NULL DO UPDATE SET othercol=EXCLUDED.othercol", + ["oldvalue", "newvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_many(self) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert_many( + table="tablename", + key_names=["keycol1", "keycol2"], + key_values=[["keyval1", "keyval2"], ["keyval3", "keyval4"]], + value_names=["valuecol3"], + value_values=[["val5"], ["val6"]], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES ? ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3", + [("keyval1", "keyval2", "val5"), ("keyval3", "keyval4", "val6")], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES (?, ?, ?) ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3", + [("keyval1", "keyval2", "val5"), ("keyval3", "keyval4", "val6")], + ) + + @defer.inlineCallbacks + def test_upsert_many_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert_many( + table="tablename", + key_names=["columnname"], + key_values=[["oldvalue"]], + value_names=[], + value_values=[], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (columnname) VALUES ? ON CONFLICT (columnname) DO NOTHING", + [("oldvalue",)], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (columnname) VALUES (?) ON CONFLICT (columnname) DO NOTHING", + [("oldvalue",)], + ) + + @defer.inlineCallbacks + def test_upsert_emulated_no_values_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.fetchall.return_value = [(1,)] + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "value"}, + values={}, + insertion_values={"columnname": "value"}, + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call("SELECT 1 FROM tablename WHERE columnname = ?", ["value"]), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "SELECT 1 FROM tablename WHERE columnname = ?", ["value"] + ) + self.assertFalse(result) + + @defer.inlineCallbacks + def test_upsert_emulated_no_values_not_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.fetchall.return_value = [] + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "value"}, + values={}, + insertion_values={"columnname": "value"}, + ) + ) + + self.mock_txn.execute.assert_has_calls( + [ + call( + "SELECT 1 FROM tablename WHERE columnname = ?", + ["value"], + ), + call("INSERT INTO tablename (columnname) VALUES (?)", ["value"]), + ], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_insertion_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + insertion_values={"thirdcol": "insertionval"}, + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call( + "UPDATE tablename SET othercol = ? WHERE columnname = ?", + ["newvalue", "oldvalue"], + ), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "UPDATE tablename SET othercol = ? WHERE columnname = ?", + ["newvalue", "oldvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_insertion_not_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 0 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + insertion_values={"thirdcol": "insertionval"}, + ) + ) + + self.mock_txn.execute.assert_has_calls( + [ + call( + "UPDATE tablename SET othercol = ? WHERE columnname = ?", + ["newvalue", "oldvalue"], + ), + call( + "INSERT INTO tablename (columnname, othercol, thirdcol) VALUES (?, ?, ?)", + ["oldvalue", "newvalue", "insertionval"], + ), + ] + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_where( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + where_clause="thirdcol IS NULL", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call( + "UPDATE tablename SET othercol = ? WHERE columnname = ? AND thirdcol IS NULL", + ["newvalue", "oldvalue"], + ), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "UPDATE tablename SET othercol = ? WHERE columnname = ? AND thirdcol IS NULL", + ["newvalue", "oldvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_where_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={}, + where_clause="thirdcol IS NULL", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call( + "SELECT 1 FROM tablename WHERE columnname = ? AND thirdcol IS NULL", + ["oldvalue"], + ), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "SELECT 1 FROM tablename WHERE columnname = ? AND thirdcol IS NULL", + ["oldvalue"], + ) + self.assertFalse(result)