From 09957ce0e4dcfd84c2de4039653059faae03065b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 4 Nov 2019 17:09:22 +0000 Subject: Implement per-room message retention policies --- tests/rest/client/test_retention.py | 320 ++++++++++++++++++++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 tests/rest/client/test_retention.py (limited to 'tests') diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..41ea9db689 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 4, + }, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_hour_ms, + }, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": lifetime, + }, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success( + store.get_event( + resp.get("event_id") + ) + )) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success( + store.get_event( + valid_event_id + ) + )) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success(filter_events_for_client( + storage, self.user_id, events + )) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 35, + }, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body -- cgit 1.5.1 From a7c818c79b70d6b70abc5b26f0e1e78fd60c087e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 13:21:26 +0000 Subject: Add test case --- tests/rest/client/v1/test_rooms.py | 182 +++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5e38fd6ced..621c894e35 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1106,3 +1106,185 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class ContextTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + def test_context_filter_labels(self): + """Test that we can filter by a label.""" + context_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + ) + + res = self._test_context_filter_labels(context_filter) + + self.assertEqual( + res["event"]["content"]["body"], "with right label", res["event"] + ) + + events_before = res["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = res["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + context_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + ) + + res = self._test_context_filter_labels(context_filter) + + events_before = res["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = res["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + context_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) + + res = self._test_context_filter_labels(context_filter) + + events_before = res["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = res["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def _test_context_filter_labels(self, context_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, + ) + + # The event we'll look up the context for. + res = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + event_id = res["event_id"] + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (room_id, event_id, context_filter), + access_token=tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body + -- cgit 1.5.1 From c9e4748cb75271a2178d0cae05d551829249ada3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 13:47:47 +0000 Subject: Merge labels tests for /context and /messages --- tests/rest/client/v1/test_rooms.py | 276 +++++++++++++++++-------------------- 1 file changed, 130 insertions(+), 146 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 621c894e35..fe327d1bf8 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -811,105 +811,6 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_filter_labels(self): - """Test that we can filter by a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - - events = self._test_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_filter_labels(self, message_filter): - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&filter=%s" - % (self.room_id, token, message_filter), - ) - self.render(request) - - return channel.json_body["chunk"] - class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ @@ -1108,7 +1009,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.assertEqual(res_displayname, self.displayname, channel.result) -class ContextTestCase(unittest.HomeserverTestCase): +class LabelsTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -1116,8 +1017,13 @@ class ContextTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + def test_context_filter_labels(self): - """Test that we can filter by a label.""" + """Test that we can filter by a label on a /context request.""" context_filter = json.dumps( { "types": [EventTypes.Message], @@ -1125,13 +1031,17 @@ class ContextTestCase(unittest.HomeserverTestCase): } ) - res = self._test_context_filter_labels(context_filter) + event_id = self._send_labelled_messages_in_room() - self.assertEqual( - res["event"]["content"]["body"], "with right label", res["event"] + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + access_token=self.tok, ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) - events_before = res["events_before"] + events_before = channel.json_body["events_before"] self.assertEqual( len(events_before), 1, [event["content"] for event in events_before] @@ -1140,7 +1050,7 @@ class ContextTestCase(unittest.HomeserverTestCase): events_before[0]["content"]["body"], "with right label", events_before[0] ) - events_after = res["events_before"] + events_after = channel.json_body["events_before"] self.assertEqual( len(events_after), 1, [event["content"] for event in events_after] @@ -1150,7 +1060,7 @@ class ContextTestCase(unittest.HomeserverTestCase): ) def test_context_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" + """Test that we can filter by the absence of a label on a /context request.""" context_filter = json.dumps( { "types": [EventTypes.Message], @@ -1158,9 +1068,17 @@ class ContextTestCase(unittest.HomeserverTestCase): } ) - res = self._test_context_filter_labels(context_filter) + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) - events_before = res["events_before"] + events_before = channel.json_body["events_before"] self.assertEqual( len(events_before), 1, [event["content"] for event in events_before] @@ -1169,7 +1087,7 @@ class ContextTestCase(unittest.HomeserverTestCase): events_before[0]["content"]["body"], "without label", events_before[0] ) - events_after = res["events_after"] + events_after = channel.json_body["events_after"] self.assertEqual( len(events_after), 2, [event["content"] for event in events_after] @@ -1182,7 +1100,9 @@ class ContextTestCase(unittest.HomeserverTestCase): ) def test_context_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ context_filter = json.dumps( { "types": [EventTypes.Message], @@ -1191,15 +1111,23 @@ class ContextTestCase(unittest.HomeserverTestCase): } ) - res = self._test_context_filter_labels(context_filter) + event_id = self._send_labelled_messages_in_room() - events_before = res["events_before"] + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] self.assertEqual( len(events_before), 0, [event["content"] for event in events_before] ) - events_after = res["events_after"] + events_after = channel.json_body["events_after"] self.assertEqual( len(events_after), 1, [event["content"] for event in events_after] @@ -1208,83 +1136,139 @@ class ContextTestCase(unittest.HomeserverTestCase): events_after[0]["content"]["body"], "with wrong label", events_after[0] ) - def _test_context_filter_labels(self, context_filter): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + ) + + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, message_filter), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + ) + + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, message_filter), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + message_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) + + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, message_filter), + ) + self.render(request) + + events = channel.json_body["chunk"] - room_id = self.helper.create_room_as(user_id, tok=tok) + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + def _send_labelled_messages_in_room(self): self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with right label", EventContentFields.LABELS: ["#fun"], }, - tok=tok, + tok=self.tok, ) self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "without label"}, - tok=tok, + tok=self.tok, ) - # The event we'll look up the context for. res = self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, ) event_id = res["event_id"] self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with wrong label", EventContentFields.LABELS: ["#work"], }, - tok=tok, + tok=self.tok, ) self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with two wrong labels", EventContentFields.LABELS: ["#work", "#notfun"], }, - tok=tok, + tok=self.tok, ) self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with right label", EventContentFields.LABELS: ["#fun"], }, - tok=tok, + tok=self.tok, ) - request, channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" % (room_id, event_id, context_filter), - access_token=tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.result) - - return channel.json_body - + return event_id -- cgit 1.5.1 From 037360e6cf2ca181b7cf03884375d4a4d52ad64e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 14:33:18 +0000 Subject: Add tests for /search --- tests/rest/client/v1/test_rooms.py | 187 ++++++++++++++++++++++++++++--------- 1 file changed, 143 insertions(+), 44 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fe327d1bf8..cc7499dcc0 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1017,6 +1017,18 @@ class LabelsTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") @@ -1024,18 +1036,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_context_filter_labels(self): """Test that we can filter by a label on a /context request.""" - context_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - ) - event_id = self._send_labelled_messages_in_room() request, channel = self.make_request( "GET", - "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) self.render(request) @@ -1061,18 +1067,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_context_filter_not_labels(self): """Test that we can filter by the absence of a label on a /context request.""" - context_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - ) - event_id = self._send_labelled_messages_in_room() request, channel = self.make_request( "GET", - "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) self.render(request) @@ -1103,19 +1103,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /context request. """ - context_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - event_id = self._send_labelled_messages_in_room() request, channel = self.make_request( "GET", - "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) self.render(request) @@ -1138,17 +1131,13 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_messages_filter_labels(self): """Test that we can filter by a label on a /messages request.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, message_filter), + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), ) self.render(request) @@ -1160,17 +1149,13 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_messages_filter_not_labels(self): """Test that we can filter by the absence of a label on a /messages request.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} - ) - self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, message_filter), + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), ) self.render(request) @@ -1188,21 +1173,13 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /messages request. """ - message_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, message_filter), + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS_NOT_LABELS)), ) self.render(request) @@ -1211,7 +1188,128 @@ class LabelsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps({ + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + }) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 2, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps({ + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + }) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 4, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps({ + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + }) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 1, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + + Returns: + The ID of the event to use if we're testing filtering on /context. + """ self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, @@ -1236,6 +1334,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "without label"}, tok=self.tok, ) + # Return this event's ID when we test filtering in /context requests. event_id = res["event_id"] self.helper.send_event( -- cgit 1.5.1 From 8822b331114a2f6fdcd5916f0c91991c0acae07e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 10:56:39 +0000 Subject: Update copyrights --- synapse/api/constants.py | 3 ++- synapse/api/filtering.py | 3 +++ synapse/rest/client/versions.py | 3 +++ synapse/storage/data_stores/main/stream.py | 3 +++ tests/api/test_filtering.py | 3 +++ tests/rest/client/v1/test_rooms.py | 2 ++ tests/rest/client/v1/utils.py | 3 +++ tests/rest/client/v2_alpha/test_sync.py | 3 ++- 8 files changed, 21 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 49c4b85054..312acff3d6 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bec13f08d8..6eab1f13f0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bb30ce3f34..2a477ad22e 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 616ef91d4e..9cac664880 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2dc5052249..63d8633582 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index cc7499dcc0..b2c1ef6f0e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8ea0cb05ea..e7417b3d14 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 3283c0e47b..661c1f88b9 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- cgit 1.5.1 From a6863da24934dcbb2ae09a9e0b6e37140ef390ff Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 14:50:19 +0000 Subject: Lint --- tests/rest/client/v1/test_rooms.py | 71 ++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index b2c1ef6f0e..c5d67fc1cd 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1020,9 +1020,15 @@ class LabelsTestCase(unittest.HomeserverTestCase): ] # Filter that should only catch messages with the label "#fun". - FILTER_LABELS = {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } # Filter that should only catch messages without the label "#fun". - FILTER_NOT_LABELS = {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } # Filter that should only catch messages with the label "#work" but without the label # "#notfun". FILTER_LABELS_NOT_LABELS = { @@ -1181,7 +1187,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), ) self.render(request) @@ -1192,14 +1203,16 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self): """Test that we can filter by a label on a /search request.""" - request_data = json.dumps({ - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } } } - }) + ) self._send_labelled_messages_in_room() @@ -1211,9 +1224,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), - 2, - [result["result"]["content"] for result in results], + len(results), 2, [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1228,14 +1239,16 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self): """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps({ - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } } } - }) + ) self._send_labelled_messages_in_room() @@ -1247,9 +1260,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), - 4, - [result["result"]["content"] for result in results], + len(results), 4, [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1276,14 +1287,16 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps({ - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } } } - }) + ) self._send_labelled_messages_in_room() @@ -1295,9 +1308,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), - 1, - [result["result"]["content"] for result in results], + len(results), 1, [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], -- cgit 1.5.1 From f03c9d34442368c8f2c26d2ac16b770bc451c76d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 6 Nov 2019 15:47:40 +0000 Subject: Don't apply retention policy based filtering on state events As per MSC1763, 'Retention is only considered for non-state events.', so don't filter out state events based on the room's retention policy. --- synapse/visibility.py | 15 +++++++++------ tests/rest/client/test_retention.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/visibility.py b/synapse/visibility.py index 4498c156bc..4d4141dacc 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -111,14 +111,17 @@ def filter_events_for_client( if not event.is_state() and event.sender in ignore_list: return None - retention_policy = retention_policies[event.room_id] - max_lifetime = retention_policy.get("max_lifetime") + # Don't try to apply the room's retention policy if the event is a state event, as + # MSC1763 states that retention is only considered for non-state events. + if not event.is_state(): + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") - if max_lifetime is not None: - oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime - if event.origin_server_ts < oldest_allowed_ts: - return None + if event.origin_server_ts < oldest_allowed_ts: + return None if event.event_id in always_include_ids: return event diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 41ea9db689..7b6f25a838 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -164,6 +164,12 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) def _test_retention_event_purged(self, room_id, increment): + # Get the create event to, later, check that we can still access it. + message_handler = self.hs.get_message_handler() + create_event = self.get_success( + message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + ) + # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. resp = self.helper.send( @@ -202,6 +208,10 @@ class RetentionTestCase(unittest.HomeserverTestCase): valid_event = self.get_event(room_id, valid_event_id) self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + # Check that we can still access state events that were sent before the event that + # has been purged. + self.get_event(room_id, create_event.event_id) + def get_event(self, room_id, event_id, expected_code=200): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) -- cgit 1.5.1 From 7c24d0f443724082376c89f9f75954d81f524a8e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 19 Nov 2019 13:22:37 +0000 Subject: Lint --- synapse/config/server.py | 39 ++++++++++------- synapse/handlers/pagination.py | 17 +++----- synapse/storage/data_stores/main/room.py | 49 +++++++++++---------- tests/rest/client/test_retention.py | 73 ++++++++++---------------------- 4 files changed, 77 insertions(+), 101 deletions(-) (limited to 'tests') diff --git a/synapse/config/server.py b/synapse/config/server.py index aa93a416f1..8a55ffac4f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import List +from typing import List, Dict, Optional import attr import yaml @@ -287,13 +287,17 @@ class ServerConfig(Config): self.retention_default_min_lifetime = None self.retention_default_max_lifetime = None - self.retention_allowed_lifetime_min = retention_config.get("allowed_lifetime_min") + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) if self.retention_allowed_lifetime_min is not None: self.retention_allowed_lifetime_min = self.parse_duration( self.retention_allowed_lifetime_min ) - self.retention_allowed_lifetime_max = retention_config.get("allowed_lifetime_max") + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) if self.retention_allowed_lifetime_max is not None: self.retention_allowed_lifetime_max = self.parse_duration( self.retention_allowed_lifetime_max @@ -302,14 +306,15 @@ class ServerConfig(Config): if ( self.retention_allowed_lifetime_min is not None and self.retention_allowed_lifetime_max is not None - and self.retention_allowed_lifetime_min > self.retention_allowed_lifetime_max + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max ): raise ConfigError( "Invalid retention policy limits: 'allowed_lifetime_min' can not be" " greater than 'allowed_lifetime_max'" ) - self.retention_purge_jobs = [] + self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] for purge_job_config in retention_config.get("purge_jobs", []): interval_config = purge_job_config.get("interval") @@ -342,18 +347,22 @@ class ServerConfig(Config): " 'longest_max_lifetime' value." ) - self.retention_purge_jobs.append({ - "interval": interval, - "shortest_max_lifetime": shortest_max_lifetime, - "longest_max_lifetime": longest_max_lifetime, - }) + self.retention_purge_jobs.append( + { + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + } + ) if not self.retention_purge_jobs: - self.retention_purge_jobs = [{ - "interval": self.parse_duration("1d"), - "shortest_max_lifetime": None, - "longest_max_lifetime": None, - }] + self.retention_purge_jobs = [ + { + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + } + ] self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index e1800177fa..d122c11a4d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -154,20 +154,17 @@ class PaginationHandler(object): # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime - stream_ordering = ( - yield self.store.find_first_stream_ordering_after_ts(ts) - ) + stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) - r = ( - yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering, - ) + r = yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, ) if not r: logger.warning( "[purge] purging events not possible: No event found " "(ts %i => stream_ordering %i)", - ts, stream_ordering, + ts, + stream_ordering, ) continue @@ -186,9 +183,7 @@ class PaginationHandler(object): # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", - self._purge_history, - purge_id, room_id, token, True, + "_purge_history", self._purge_history, purge_id, room_id, token, True, ) def start_purge_history(self, room_id, token, delete_local_events=False): diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 54a7d24c73..7fceae59ca 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -334,8 +334,9 @@ class RoomStore(RoomWorkerStore, SearchStore): WHERE state.room_id > ? AND state.type = '%s' ORDER BY state.room_id ASC LIMIT ?; - """ % EventTypes.Retention, - (last_room, batch_size) + """ + % EventTypes.Retention, + (last_room, batch_size), ) rows = self.cursor_to_dict(txn) @@ -358,15 +359,13 @@ class RoomStore(RoomWorkerStore, SearchStore): "event_id": row["event_id"], "min_lifetime": retention_policy.get("min_lifetime"), "max_lifetime": retention_policy.get("max_lifetime"), - } + }, ) logger.info("Inserted %d rows into room_retention", len(rows)) self._background_update_progress_txn( - txn, "insert_room_retention", { - "room_id": rows[-1]["room_id"], - } + txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) if batch_size > len(rows): @@ -375,8 +374,7 @@ class RoomStore(RoomWorkerStore, SearchStore): return False end = yield self.runInteraction( - "insert_room_retention", - _background_insert_retention_txn, + "insert_room_retention", _background_insert_retention_txn, ) if end: @@ -585,17 +583,15 @@ class RoomStore(RoomWorkerStore, SearchStore): ) def _store_retention_policy_for_room_txn(self, txn, event): - if ( - hasattr(event, "content") - and ("min_lifetime" in event.content or "max_lifetime" in event.content) + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content ): if ( - ("min_lifetime" in event.content and not isinstance( - event.content.get("min_lifetime"), integer_types - )) - or ("max_lifetime" in event.content and not isinstance( - event.content.get("max_lifetime"), integer_types - )) + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) ): # Ignore the event if one of the value isn't an integer. return @@ -798,7 +794,9 @@ class RoomStore(RoomWorkerStore, SearchStore): return local_media_mxcs, remote_media_mxcs @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False): + def get_rooms_for_retention_period_in_range( + self, min_ms, max_ms, include_null=False + ): """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. @@ -904,23 +902,24 @@ class RoomStore(RoomWorkerStore, SearchStore): INNER JOIN current_state_events USING (event_id, room_id) WHERE room_id = ?; """, - (room_id,) + (room_id,), ) return self.cursor_to_dict(txn) ret = yield self.runInteraction( - "get_retention_policy_for_room", - get_retention_policy_for_room_txn, + "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue({ - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - }) + defer.returnValue( + { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + ) row = ret[0] diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 7b6f25a838..6bf485c239 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -61,9 +61,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_day_ms * 4, - }, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, expect_code=400, ) @@ -71,9 +69,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_hour_ms, - }, + body={"max_lifetime": one_hour_ms}, tok=self.token, expect_code=400, ) @@ -89,9 +85,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": lifetime, - }, + body={"max_lifetime": lifetime}, tok=self.token, ) @@ -115,20 +109,12 @@ class RetentionTestCase(unittest.HomeserverTestCase): events = [] # Send a first event, which should be filtered out at the end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) # Get the event from the store so that we end up with a FrozenEvent that we can # give to filter_events_for_client. We need to do this now because the event won't # be in the database anymore after it has expired. - events.append(self.get_success( - store.get_event( - resp.get("event_id") - ) - )) + events.append(self.get_success(store.get_event(resp.get("event_id")))) # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -143,20 +129,16 @@ class RetentionTestCase(unittest.HomeserverTestCase): valid_event_id = resp.get("event_id") - events.append(self.get_success( - store.get_event( - valid_event_id - ) - )) + events.append(self.get_success(store.get_event(valid_event_id))) # Advance the time by anothe 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) # Run filter_events_for_client with our list of FrozenEvents. - filtered_events = self.get_success(filter_events_for_client( - storage, self.user_id, events - )) + filtered_events = self.get_success( + filter_events_for_client(storage, self.user_id, events) + ) # We should only get one event back. self.assertEqual(len(filtered_events), 1, filtered_events) @@ -172,28 +154,22 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") # Check that we can retrieve the event. expired_event = self.get_event(room_id, expired_event_id) - self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) # Advance the time. self.reactor.advance(increment / 1000) # Send another event. We need this because the purge job won't purge the most # recent event in the room. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") @@ -240,8 +216,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, - federation_client=mock_federation_client, + config=config, federation_client=mock_federation_client, ) return self.hs @@ -268,9 +243,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_day_ms * 35, - }, + body={"max_lifetime": one_day_ms * 35}, tok=self.token, ) @@ -289,18 +262,16 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) - self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) # Advance the time by a month. self.reactor.advance(one_day_ms * 30 / 1000) # Send another event. We need this because the purge job won't purge the most # recent event in the room. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") @@ -313,7 +284,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): ) if expected_code_for_first_event == 200: - self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) # Check that the event that hasn't been purged can still be retrieved. second_event = self.get_event(room_id, second_event_id) -- cgit 1.5.1 From bf9a11c54d3c9ca1e4fa9420b567efceb1e46d5b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 19 Nov 2019 13:30:04 +0000 Subject: Lint again --- synapse/config/server.py | 2 +- tests/rest/client/test_retention.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/synapse/config/server.py b/synapse/config/server.py index 8a55ffac4f..ed916e1400 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import List, Dict, Optional +from typing import Dict, List, Optional import attr import yaml diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 6bf485c239..9e549d8a91 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -121,11 +121,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.reactor.advance(one_day_ms * 2 / 1000) # Send another event, which shouldn't get filtered out. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") @@ -252,11 +248,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def _test_retention(self, room_id, expected_code_for_first_event=200): # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") -- cgit 1.5.1 From 6356f2088f0adb681fe24a8435955b19883fa3b4 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 20 Nov 2019 12:09:06 +0000 Subject: Test if a purge can make /messages return 500 responses --- tests/rest/client/v1/test_rooms.py | 72 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5e38fd6ced..ebaa67e899 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -25,7 +25,9 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import login, profile, room +from synapse.util.stringutils import random_string from tests import unittest @@ -910,6 +912,76 @@ class RoomMessageListTestCase(RoomBase): return channel.json_body["chunk"] + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success(pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + )) + + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From e2a20326e8141fdf9304434901da38c64b917a78 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 20 Nov 2019 15:08:47 +0000 Subject: Lint --- tests/rest/client/v1/test_rooms.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ebaa67e899..e84e578f99 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -948,12 +948,14 @@ class RoomMessageListTestCase(RoomBase): # Purge every event before the second event. purge_id = random_string(16) pagination_handler._purges_by_id[purge_id] = PurgeStatus() - self.get_success(pagination_handler._purge_history( - purge_id=purge_id, - room_id=self.room_id, - token=second_token, - delete_local_events=True, - )) + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) + ) # Check that we only get the second message through /message now that the first # has been purged. -- cgit 1.5.1 From 9eebd46048d0b34767047b2156760a1467f19ae6 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 26 Nov 2019 03:45:50 +1100 Subject: Improve the performance of structured logging (#6322) --- changelog.d/6322.misc | 1 + synapse/logging/_structured.py | 14 +++++- synapse/logging/_terse_json.py | 106 ++++++++++++++++++++++++++++++----------- tests/server.py | 2 + 4 files changed, 93 insertions(+), 30 deletions(-) create mode 100644 changelog.d/6322.misc (limited to 'tests') diff --git a/changelog.d/6322.misc b/changelog.d/6322.misc new file mode 100644 index 0000000000..70ef36ca80 --- /dev/null +++ b/changelog.d/6322.misc @@ -0,0 +1 @@ +Improve the performance of outputting structured logging. diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 334ddaf39a..ffa7b20ca8 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -261,6 +261,18 @@ def parse_drain_configs( ) +class StoppableLogPublisher(LogPublisher): + """ + A log publisher that can tell its observers to shut down any external + communications. + """ + + def stop(self): + for obs in self._observers: + if hasattr(obs, "stop"): + obs.stop() + + def setup_structured_logging( hs, config, @@ -336,7 +348,7 @@ def setup_structured_logging( # We should never get here, but, just in case, throw an error. raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - publisher = LogPublisher(*observers) + publisher = StoppableLogPublisher(*observers) log_filter = LogLevelFilterPredicate() for namespace, namespace_config in log_config.get( diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 76ce7d8808..05fc64f409 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -17,25 +17,29 @@ Log formatters that output terse JSON. """ +import json import sys +import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import IO +from typing import IO, Optional import attr -from simplejson import dumps from zope.interface import implementer from twisted.application.internet import ClientService +from twisted.internet.defer import Deferred from twisted.internet.endpoints import ( HostnameEndpoint, TCP4ClientEndpoint, TCP6ClientEndpoint, ) +from twisted.internet.interfaces import IPushProducer, ITransport from twisted.internet.protocol import Factory, Protocol from twisted.logger import FileLogObserver, ILogObserver, Logger -from twisted.python.failure import Failure + +_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) def flatten_event(event: dict, metadata: dict, include_time: bool = False): @@ -141,11 +145,49 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb def formatEvent(_event: dict) -> str: flattened = flatten_event(_event, metadata) - return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n" + return _encoder.encode(flattened) + "\n" return FileLogObserver(outFile, formatEvent) +@attr.s +@implementer(IPushProducer) +class LogProducer(object): + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + """ + + transport = attr.ib(type=ITransport) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = None + + def resumeProducing(self): + self._paused = False + + while self._paused is False and (self._buffer and self.transport.connected): + try: + event = self._buffer.popleft() + self.transport.write(_encoder.encode(event).encode("utf8")) + self.transport.write(b"\n") + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + @attr.s @implementer(ILogObserver) class TerseJSONToTCPLogObserver(object): @@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object): metadata = attr.ib(type=dict) maximum_buffer = attr.ib(type=int) _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _writer = attr.ib(default=None) + _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) _logger = attr.ib(default=attr.Factory(Logger)) + _producer = attr.ib(default=None, type=Optional[LogProducer]) def start(self) -> None: @@ -187,38 +230,43 @@ class TerseJSONToTCPLogObserver(object): factory = Factory.forProtocol(Protocol) self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) self._service.startService() + self._connect() - def _write_loop(self) -> None: + def stop(self): + self._service.stopService() + + def _connect(self) -> None: """ - Implement the write loop. + Triggers an attempt to connect then write to the remote if not already writing. """ - if self._writer: + if self._connection_waiter: return - self._writer = self._service.whenConnected() + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + @self._connection_waiter.addErrback + def fail(r): + r.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() - @self._writer.addBoth + @self._connection_waiter.addCallback def writer(r): - if isinstance(r, Failure): - r.printTraceback(file=sys.__stderr__) - self._writer = None - self.hs.get_reactor().callLater(1, self._write_loop) + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and r.transport is self._producer.transport: + self._producer.resumeProducing() return - try: - for event in self._buffer: - r.transport.write( - dumps(event, ensure_ascii=False, separators=(",", ":")).encode( - "utf8" - ) - ) - r.transport.write(b"\n") - self._buffer.clear() - except Exception as e: - sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),)) - - self._writer = False - self.hs.get_reactor().callLater(1, self._write_loop) + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer(buffer=self._buffer, transport=r.transport) + r.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None def _handle_pressure(self) -> None: """ @@ -277,4 +325,4 @@ class TerseJSONToTCPLogObserver(object): self._logger.failure("Failed clearing backpressure") # Try and write immediately. - self._write_loop() + self._connect() diff --git a/tests/server.py b/tests/server.py index f878aeaada..2b7cf4242e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -379,6 +379,7 @@ class FakeTransport(object): disconnecting = False disconnected = False + connected = True buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) @@ -402,6 +403,7 @@ class FakeTransport(object): "FakeTransport: Delaying disconnect until buffer is flushed" ) else: + self.connected = False self.disconnected = True def abortConnection(self): -- cgit 1.5.1 From f0ef9708241ec65fe6f32f1ad36f719ab4ab2b53 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 26 Nov 2019 17:49:12 +0000 Subject: Don't restrict the tests to v1 rooms --- tests/rest/client/test_retention.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 9e549d8a91..95475bb651 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -34,7 +34,6 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["default_room_version"] = "1" config["retention"] = { "enabled": True, "default_policy": { @@ -204,7 +203,6 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["default_room_version"] = "1" config["retention"] = { "enabled": True, } -- cgit 1.5.1 From ce578031f4d0fe6f1eb26de4cb3d30a4175468db Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 26 Nov 2019 18:42:27 +0000 Subject: Remove assertion and provide a clear warning on startup for missing public_baseurl (#6379) --- changelog.d/6379.misc | 1 + synapse/config/emailconfig.py | 2 ++ synapse/config/registration.py | 7 +++++++ tests/rest/client/v2_alpha/test_register.py | 1 + 4 files changed, 11 insertions(+) create mode 100644 changelog.d/6379.misc (limited to 'tests') diff --git a/changelog.d/6379.misc b/changelog.d/6379.misc new file mode 100644 index 0000000000..725c2e7d87 --- /dev/null +++ b/changelog.d/6379.misc @@ -0,0 +1 @@ +Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. \ No newline at end of file diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 43fad0bf8b..ac1724045f 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -146,6 +146,8 @@ class EmailConfig(Config): if k not in email_config: missing.append("email." + k) + # public_baseurl is required to build password reset and validation links that + # will be emailed to users if config.get("public_baseurl") is None: missing.append("public_baseurl") diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1f6dac69da..ee9614c5f7 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -106,6 +106,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index dab87e5edf..c0d0d2b44e 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -203,6 +203,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @unittest.override_config( { + "public_baseurl": "https://test_server", "enable_registration_captcha": True, "user_consent": { "version": "1", -- cgit 1.5.1 From 0d27aba900136514a8801b902f9a8ac69150e2c0 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 27 Nov 2019 16:14:44 -0500 Subject: add etag and count to key backup endpoints (#5858) --- changelog.d/5858.feature | 1 + synapse/handlers/e2e_room_keys.py | 130 +++++++----- synapse/rest/client/v2_alpha/room_keys.py | 8 +- synapse/storage/data_stores/main/e2e_room_keys.py | 226 +++++++++++++++------ .../main/schema/delta/56/room_key_etag.sql | 17 ++ tests/handlers/test_e2e_room_keys.py | 31 +++ tests/storage/test_e2e_room_keys.py | 8 +- 7 files changed, 297 insertions(+), 124 deletions(-) create mode 100644 changelog.d/5858.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql (limited to 'tests') diff --git a/changelog.d/5858.feature b/changelog.d/5858.feature new file mode 100644 index 0000000000..55ee93051e --- /dev/null +++ b/changelog.d/5858.feature @@ -0,0 +1 @@ +Add etag and count fields to key backup endpoints to help clients guess if there are new keys. diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 0cea445f0d..f1b4424a02 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object): rooms session_id(string): session ID to delete keys for, for None to delete keys for all sessions + Raises: + NotFoundError: if the backup version does not exist Returns: - A deferred of the deletion transaction + A dict containing the count and etag for the backup version """ # lock for consistency with uploading with (yield self._upload_linearizer.queue(user_id)): + # make sure the backup version exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + version_etag = version_info["etag"] + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag + ) + + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} + @trace @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): @@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object): } } + Returns: + A dict containing the count and etag for the backup version + Raises: NotFoundError: if there are no versions defined RoomKeysVersionError: if the uploaded version is not the current version @@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. + existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session + for session_id, room_key in iteritems(room["sessions"]): + log_kv( + { + "message": "Trying to upload room key", + "room_id": room_id, + "session_id": session_id, + "user_id": user_id, + } ) - - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ - log_kv( - { - "message": "Trying to upload room key", - "room_id": room_id, - "session_id": session_id, - "user_id": user_id, - } - ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - log_kv( - { - "message": "Room key not found.", - "room_id": room_id, - "user_id": user_id, - } + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + log_kv({"message": "Replacing room key."}) + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + log_kv({"message": "Not replacing room_key."}) + else: + log_kv( + { + "message": "Room key not found.", + "room_id": room_id, + "user_id": user_id, + } + ) + log_kv({"message": "Replacing room key."}) + to_insert.append((room_id, session_id, room_key)) + changed = True + + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) + + version_etag = version_info["etag"] + if changed: + version_etag = version_etag + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag ) - else: - raise - if self._should_replace_room_key(current_room_key, room_key): - log_kv({"message": "Replacing room key."}) - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) - else: - log_kv({"message": "Not replacing room_key."}) + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @trace diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d596786430..d83ac8e3c5 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -134,8 +134,8 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): @@ -239,10 +239,10 @@ class RoomKeysServlet(RestServlet): user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 1cbbae5b63..113224fd7c 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self._simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore): "room_id": room_id, "session_id": session_id, }, - values={ + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key", ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self._simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace @@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): room, or a given session. Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id (str): Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) @@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self._simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @trace @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str) algorithm(str) auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn): @@ -236,10 +316,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "etag"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 return result return self.runInteraction( @@ -288,21 +370,33 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version(self, user_id, version, info): + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): """Update a given backup version Args: user_id(str): the user whose backup version we're updating version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. If + None, then the etag is not updated """ + updatevalues = {} - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self._simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) @trace def delete_e2e_room_keys_version(self, user_id, version=None): diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 0bb96674a2..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index d128fde441..35dafbb904 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] ) ) @@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] ) ) -- cgit 1.5.1 From 0f87b912aba7e678041632bc9a6d1f7c2d24342c Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 28 Nov 2019 08:54:07 +1100 Subject: Implementation of MSC2314 (#6176) --- changelog.d/6176.feature | 1 + synapse/federation/federation_server.py | 26 ++++++++---- synapse/federation/transport/server.py | 6 +-- sytest-blacklist | 6 ++- tests/federation/test_complexity.py | 28 ++---------- tests/federation/test_federation_sender.py | 4 +- tests/federation/test_federation_server.py | 63 +++++++++++++++++++++++++++ tests/handlers/test_typing.py | 3 ++ tests/replication/slave/storage/_base.py | 3 ++ tests/replication/tcp/streams/_base.py | 4 ++ tests/storage/test_roommember.py | 26 +----------- tests/unittest.py | 68 +++++++++++++++++++++++++++++- tests/utils.py | 1 + 13 files changed, 174 insertions(+), 65 deletions(-) create mode 100644 changelog.d/6176.feature (limited to 'tests') diff --git a/changelog.d/6176.feature b/changelog.d/6176.feature new file mode 100644 index 0000000000..3c66d689d4 --- /dev/null +++ b/changelog.d/6176.feature @@ -0,0 +1 @@ +Implement the `/_matrix/federation/unstable/net.atleastfornow/state/` API as drafted in MSC2314. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d942d77a72..84d4eca041 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,6 +74,7 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler + self.state = hs.get_state_handler() self._server_linearizer = Linearizer("fed_server") self._transaction_linearizer = Linearizer("fed_txn_handler") @@ -264,9 +266,6 @@ class FederationServer(FederationBase): await self.registry.on_edu(edu_type, origin, content) async def on_context_state_request(self, origin, room_id, event_id): - if not event_id: - raise NotImplementedError("Specify an event") - origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -280,13 +279,18 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = await self._state_resp_cache.wrap( - (room_id, event_id), - self._on_context_state_request_compute, - room_id, - event_id, + resp = dict( + await self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, + event_id, + ) ) + room_version = await self.store.get_room_version(room_id) + resp["room_version"] = room_version + return 200, resp async def on_state_ids_request(self, origin, room_id, event_id): @@ -306,7 +310,11 @@ class FederationServer(FederationBase): return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute(self, room_id, event_id): - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + if event_id: + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + else: + pdus = (await self.state.get_current_state(room_id)).values() + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 09baa9c57d..fefc789c85 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -421,7 +421,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateServlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given context. @@ -429,7 +429,7 @@ class FederationStateServlet(BaseFederationServlet): return await self.handler.on_context_state_request( origin, context, - parse_string_from_args(query, "event_id", None, required=True), + parse_string_from_args(query, "event_id", None, required=False), ) @@ -1360,7 +1360,7 @@ class RoomComplexityServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, - FederationStateServlet, + FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, diff --git a/sytest-blacklist b/sytest-blacklist index 11785fd43f..411cce0692 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,6 +1,6 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in # Synapse. -# +# # Each line of this file is scanned by sytest during a run and if the line # exactly matches the name of a test, it will be marked as "expected fail", # meaning the test will still run, but failure will not mark the entire test @@ -29,3 +29,7 @@ Enabling an unknown default rule fails with 404 # Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 New federated private chats get full presence information (SYN-115) + +# Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing +# this requirement from the spec +Inbound federation of state requires event_id as a mandatory paramater diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 51714a2b06..24fa8dbb45 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -18,17 +18,14 @@ from mock import Mock from twisted.internet import defer from synapse.api.errors import Codes, SynapseError -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import UserID -from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest -class RoomComplexityTests(unittest.HomeserverTestCase): +class RoomComplexityTests(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -41,25 +38,6 @@ class RoomComplexityTests(unittest.HomeserverTestCase): config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config - def prepare(self, reactor, clock, homeserver): - class Authenticator(object): - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter( - clock, - FederationRateLimitConfig( - window_size=1, - sleep_limit=1, - sleep_msec=1, - reject_limit=1000, - concurrent_requests=1000, - ), - ) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - def test_complexity_simple(self): u1 = self.register_user("u1", "pass") @@ -105,7 +83,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], "roomid", UserID.from_string(u1), {"membership": "join"}, @@ -146,7 +124,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], room_1, UserID.from_string(u1), {"membership": "join"}, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cce8d8c6de..d456267b87 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.types import ReadReceipt -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class FederationSenderTestCases(HomeserverTestCase): @@ -29,6 +29,7 @@ class FederationSenderTestCases(HomeserverTestCase): federation_transport_client=Mock(spec=["send_transaction"]), ) + @override_config({"send_federation": True}) def test_send_receipts(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] @@ -69,6 +70,7 @@ class FederationSenderTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b08be451aa..1ec8c40901 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,8 @@ import logging from synapse.events import FrozenEvent from synapse.federation.federation_server import server_matches_acl_event +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from tests import unittest @@ -41,6 +44,66 @@ class ServerACLsTestCase(unittest.TestCase): self.assertTrue(server_matches_acl_event("1:2:3:4", e)) +class StateQueryTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_without_event_id(self): + """ + Querying v1/state/ without an event ID will return the current + known state. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertEqual( + channel.json_body["room_version"], + self.hs.config.default_room_version.identifier, + ) + + members = set( + map( + lambda x: x["state_key"], + filter( + lambda x: x["type"] == "m.room.member", channel.json_body["pdus"] + ), + ) + ) + + self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(len(channel.json_body["pdus"]), 6) + + def test_needs_to_be_in_room(self): + """ + Querying v1/state/ requires the server + be in the room to provide data. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + + def _create_acl_event(content): return FrozenEvent( { diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5ec568f4e6..f6d8660285 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.unittest import override_config from tests.utils import register_federation_servlets # Some local users to test with @@ -174,6 +175,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] @@ -237,6 +239,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 4f924ce451..e7472e3a93 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -48,7 +48,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer + handler_factory = Mock() self.replication_handler = ReplicationClientHandler(self.slaved_store) + self.replication_handler.factory = handler_factory + client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index ce3835ae6a..1d14e77255 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.replication.tcp.commands import ReplicateCommand from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -30,7 +32,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server = server_factory.buildProtocol(None) # build a replication client, with a dummy handler + handler_factory = Mock() self.test_handler = TestReplicationClientHandler() + self.test_handler.factory = handler_factory self.client = ClientReplicationStreamProtocol( "client", "test", clock, self.test_handler ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9ddd17f73d..105a0c2b02 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -16,8 +16,7 @@ from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID @@ -44,9 +43,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -55,26 +51,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - - self.get_success(self.storage.persistence.persist_event(event, context)) - - return event - def test_one_member(self): # Alice creates the room, and is automatically joined diff --git a/tests/unittest.py b/tests/unittest.py index 561cebc223..31997a0f31 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import hashlib import hmac @@ -27,13 +29,17 @@ from twisted.internet.defer import Deferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester +from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.test_utils.logging_setup import setup_logging @@ -559,6 +565,66 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 403, channel.result) + def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + """ + Inject a membership event into a room. + + Args: + room: Room ID to inject the event into. + user: MXID of the user to inject the membership for. + membership: The membership type. + """ + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + + room_version = self.get_success(self.hs.get_datastore().get_room_version(room)) + + builder = event_builder_factory.for_room_version( + KNOWN_ROOM_VERSIONS[room_version], + { + "type": EventTypes.Member, + "sender": user, + "state_key": user, + "room_id": room, + "content": {"membership": membership}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.hs.get_storage().persistence.persist_event(event, context) + ) + + +class FederatingHomeserverTestCase(HomeserverTestCase): + """ + A federating homeserver that authenticates incoming requests as `other.example.com`. + """ + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return succeed("other.example.com") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + federation_server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + return super().prepare(reactor, clock, homeserver) + def override_config(extra_config): """A decorator which can be applied to test functions to give additional HS config diff --git a/tests/utils.py b/tests/utils.py index 7dc9bdc505..de2ac1ed33 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -109,6 +109,7 @@ def default_config(name, parse=False): """ config_dict = { "server_name": name, + "send_federation": False, "media_store_path": "media", "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config -- cgit 1.5.1 From 8c9a713f8db1d6fcc1f876ac6fbd0e54b5e5819c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Nov 2019 11:32:06 +0000 Subject: Add tests --- tests/rest/client/v1/test_rooms.py | 140 +++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e84e578f99..eda2fabc71 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1180,3 +1180,143 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) -- cgit 1.5.1 From 54dd5dc12b0ac5c48303144c4a73ce3822209488 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 3 Dec 2019 19:19:45 +0000 Subject: Add ephemeral messages support (MSC2228) (#6409) Implement part [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). The parts that differ are: * the feature is hidden behind a configuration flag (`enable_ephemeral_messages`) * self-destruction doesn't happen for state events * only implement support for the `m.self_destruct_after` field (not the `m.self_destruct` one) * doesn't send synthetic redactions to clients because for this specific case we consider the clients to be able to destroy an event themselves, instead we just censor it (by pruning its JSON) in the database --- changelog.d/6409.feature | 1 + synapse/api/constants.py | 4 + synapse/config/server.py | 2 + synapse/handlers/federation.py | 8 ++ synapse/handlers/message.py | 123 +++++++++++++++++++- synapse/storage/data_stores/main/events.py | 126 ++++++++++++++++++++- .../main/schema/delta/56/event_expiry.sql | 21 ++++ tests/rest/client/test_ephemeral_message.py | 101 +++++++++++++++++ 8 files changed, 379 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6409.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql create mode 100644 tests/rest/client/test_ephemeral_message.py (limited to 'tests') diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature new file mode 100644 index 0000000000..653ff5a5ad --- /dev/null +++ b/changelog.d/6409.feature @@ -0,0 +1 @@ +Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e3f086f1c3..69cef369a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -147,3 +147,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/config/server.py b/synapse/config/server.py index 7a9d711669..837fbe1582 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -490,6 +490,8 @@ class ServerConfig(Config): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3267734f7..d9d0cd9eef 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -121,6 +121,7 @@ class FederationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -141,6 +142,8 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -2715,6 +2718,11 @@ class FederationHandler(BaseHandler): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3b0156f516..4f53a5f5dc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,6 +70,17 @@ class MessageHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -225,6 +244,100 @@ class MessageHandler(object): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -295,6 +408,10 @@ class EventCreationHandler(object): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -877,6 +994,10 @@ class EventCreationHandler(object): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2737a1d3ae..79c91fe284 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -130,6 +130,8 @@ class EventsStore( if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -940,6 +942,12 @@ class EventsStore( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1101,12 +1109,7 @@ class EventsStore( def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) self._simple_update_one_txn( txn, @@ -1117,6 +1120,22 @@ class EventsStore( yield self.runInteraction("_update_censor_txn", _update_censor_txn) + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self._simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + @defer.inlineCallbacks def count_daily_messages(self): """ @@ -1957,6 +1976,101 @@ class EventsStore( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self._simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self._simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body -- cgit 1.5.1 From cb0aeb147e3b3defc27866ad0e4982e63600a7ee Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 4 Dec 2019 09:46:16 +0000 Subject: privacy by default for room dir (#6355) Ensure that the the default settings for the room directory are that the it is hidden from public view by default. --- UPGRADE.rst | 17 ++++++++++ changelog.d/6354.feature | 1 + docs/sample_config.yaml | 13 ++++---- synapse/config/server.py | 26 +++++++++------- tests/federation/transport/test_server.py | 52 +++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 18 deletions(-) create mode 100644 changelog.d/6354.feature create mode 100644 tests/federation/transport/test_server.py (limited to 'tests') diff --git a/UPGRADE.rst b/UPGRADE.rst index 5ebf16a73e..d9020f2663 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,23 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.7.0 +=================== + +In an attempt to configure Synapse in a privacy preserving way, the default +behaviours of ``allow_public_rooms_without_auth`` and +``allow_public_rooms_over_federation`` have been inverted. This means that by +default, only authenticated users querying the Client/Server API will be able +to query the room directory, and relatedly that the server will not share +room directory information with other servers over federation. + +If your installation does not explicitly set these settings one way or the other +and you want either setting to be ``true`` then it will necessary to update +your homeserver configuration file accordingly. + +For more details on the surrounding context see our `explainer +`_. + Upgrading to v1.5.0 =================== diff --git a/changelog.d/6354.feature b/changelog.d/6354.feature new file mode 100644 index 0000000000..fed9db884b --- /dev/null +++ b/changelog.d/6354.feature @@ -0,0 +1 @@ +Configure privacy preserving settings by default for the room directory. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c7391f0c48..10664ae8f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,15 +54,16 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'false', requires authentication to access the server's public rooms -# directory through the client API. Defaults to 'true'. +# If set to 'true', removes the need for authentication to access the server's +# public rooms directory through the client API, meaning that anyone can +# query the room directory. Defaults to 'false'. # -#allow_public_rooms_without_auth: false +#allow_public_rooms_without_auth: true -# If set to 'false', forbids any other homeserver to fetch the server's public -# rooms directory via federation. Defaults to 'true'. +# If set to 'true', allows any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'false'. # -#allow_public_rooms_over_federation: false +#allow_public_rooms_over_federation: true # The default room version for newly created rooms. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 837fbe1582..a4bef00936 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -118,15 +118,16 @@ class ServerConfig(Config): self.allow_public_rooms_without_auth = False self.allow_public_rooms_over_federation = False else: - # If set to 'False', requires authentication to access the server's public - # rooms directory through the client API. Defaults to 'True'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. self.allow_public_rooms_without_auth = config.get( - "allow_public_rooms_without_auth", True + "allow_public_rooms_without_auth", False ) - # If set to 'False', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'True'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. self.allow_public_rooms_over_federation = config.get( - "allow_public_rooms_over_federation", True + "allow_public_rooms_over_federation", False ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -620,15 +621,16 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true - # If set to 'false', requires authentication to access the server's public rooms - # directory through the client API. Defaults to 'true'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. # - #allow_public_rooms_without_auth: false + #allow_public_rooms_without_auth: true - # If set to 'false', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'true'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. # - #allow_public_rooms_over_federation: false + #allow_public_rooms_over_federation: true # The default room version for newly created rooms. # diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py new file mode 100644 index 0000000000..27d83bb7d9 --- /dev/null +++ b/tests/federation/transport/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.unittest import override_config + + +class RoomDirectoryFederationTests(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + @override_config({"allow_public_rooms_over_federation": False}) + def test_blocked_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(403, channel.code) + + @override_config({"allow_public_rooms_over_federation": True}) + def test_open_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(200, channel.code) -- cgit 1.5.1 From 65c6aee621fecff1c6a863d6b910c973196ad6bc Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Dec 2019 14:36:39 +0000 Subject: Un-remove room purge test --- tests/rest/client/v1/test_rooms.py | 72 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4095e63aef..1ca7fa742f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -815,6 +815,78 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From ee86abb2d6e9c7d553858e814b4343bcf95af75a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 10:15:55 +0000 Subject: Remove underscore from SQLBaseStore functions --- scripts-dev/hash_history.py | 2 +- scripts/synapse_port_db | 18 +-- synapse/app/user_dir.py | 2 +- synapse/storage/_base.py | 142 ++++++++++----------- synapse/storage/background_updates.py | 14 +- synapse/storage/data_stores/main/__init__.py | 16 +-- synapse/storage/data_stores/main/account_data.py | 18 +-- synapse/storage/data_stores/main/appservice.py | 10 +- synapse/storage/data_stores/main/cache.py | 2 +- synapse/storage/data_stores/main/client_ips.py | 8 +- synapse/storage/data_stores/main/deviceinbox.py | 4 +- synapse/storage/data_stores/main/devices.py | 46 +++---- synapse/storage/data_stores/main/directory.py | 12 +- synapse/storage/data_stores/main/e2e_room_keys.py | 20 +-- .../storage/data_stores/main/end_to_end_keys.py | 20 +-- .../storage/data_stores/main/event_federation.py | 16 +-- .../storage/data_stores/main/event_push_actions.py | 12 +- synapse/storage/data_stores/main/events.py | 52 ++++---- .../storage/data_stores/main/events_bg_updates.py | 10 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/filtering.py | 2 +- synapse/storage/data_stores/main/group_server.py | 128 +++++++++---------- synapse/storage/data_stores/main/keys.py | 6 +- .../storage/data_stores/main/media_repository.py | 24 ++-- .../data_stores/main/monthly_active_users.py | 6 +- synapse/storage/data_stores/main/openid.py | 2 +- synapse/storage/data_stores/main/presence.py | 8 +- synapse/storage/data_stores/main/profile.py | 24 ++-- synapse/storage/data_stores/main/push_rule.py | 24 ++-- synapse/storage/data_stores/main/pusher.py | 26 ++-- synapse/storage/data_stores/main/receipts.py | 16 +-- synapse/storage/data_stores/main/registration.py | 92 ++++++------- synapse/storage/data_stores/main/rejections.py | 4 +- synapse/storage/data_stores/main/relations.py | 4 +- synapse/storage/data_stores/main/room.py | 34 ++--- synapse/storage/data_stores/main/roommember.py | 16 +-- synapse/storage/data_stores/main/search.py | 8 +- synapse/storage/data_stores/main/signatures.py | 2 +- synapse/storage/data_stores/main/state.py | 36 +++--- synapse/storage/data_stores/main/state_deltas.py | 2 +- synapse/storage/data_stores/main/stats.py | 28 ++-- synapse/storage/data_stores/main/stream.py | 14 +- synapse/storage/data_stores/main/tags.py | 6 +- synapse/storage/data_stores/main/transactions.py | 12 +- synapse/storage/data_stores/main/user_directory.py | 56 ++++---- .../storage/data_stores/main/user_erasure_store.py | 4 +- tests/handlers/test_stats.py | 30 ++--- tests/handlers/test_user_directory.py | 12 +- tests/rest/admin/test_admin.py | 2 +- tests/storage/test__base.py | 8 +- tests/storage/test_base.py | 18 +-- tests/storage/test_client_ips.py | 12 +- tests/storage/test_event_push_actions.py | 4 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/unittest.py | 2 +- 56 files changed, 550 insertions(+), 558 deletions(-) (limited to 'tests') diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index d20f6db176..bf3862a386 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -27,7 +27,7 @@ class Store(object): "_store_pdu_reference_hash_txn" ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + simple_insert_txn = SQLBaseStore.__dict__["simple_insert_txn"] store = Store() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index f24b8ffe67..9dd1700ff0 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -221,7 +221,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store._simple_select_one( + row = yield self.postgres_store.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -236,7 +236,7 @@ class Porter(object): ) backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -266,7 +266,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -320,7 +320,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -375,7 +375,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -452,7 +452,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -591,11 +591,11 @@ class Porter(object): # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = yield self.sqlite_store.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = yield self.postgres_store.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -722,7 +722,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 6cb100319f..0fa2b50999 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -64,7 +64,7 @@ class UserDirectorySlaveStore( super(UserDirectorySlaveStore, self).__init__(db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1ed89d9f2a..9205e550bb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -262,7 +262,7 @@ class SQLBaseStore(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -307,7 +307,7 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction( + def new_transaction( self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs ): start = monotonic_time() @@ -444,7 +444,7 @@ class SQLBaseStore(object): try: result = yield self.runWithConnection( - self._new_transaction, + self.new_transaction, desc, after_callbacks, exception_callbacks, @@ -516,7 +516,7 @@ class SQLBaseStore(object): results = list(dict(zip(col_headers, row)) for row in cursor) return results - def _execute(self, desc, decoder, query, *args): + def execute(self, desc, decoder, query, *args): """Runs a single query for a result set. Args: @@ -541,7 +541,7 @@ class SQLBaseStore(object): # no complex WHERE clauses, just a dict of values for columns. @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -557,7 +557,7 @@ class SQLBaseStore(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) + yield self.runInteraction(desc, self.simple_insert_txn, table, values) except self.database_engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -567,7 +567,7 @@ class SQLBaseStore(object): return True @staticmethod - def _simple_insert_txn(txn, table, values): + def simple_insert_txn(txn, table, values): keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -578,11 +578,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod - def _simple_insert_many_txn(txn, table, values): + def simple_insert_many_txn(txn, table, values): if not values: return @@ -611,13 +611,13 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert( + def simple_upsert( self, table, keyvalues, values, insertion_values={}, - desc="_simple_upsert", + desc="simple_upsert", lock=True, ): """ @@ -649,7 +649,7 @@ class SQLBaseStore(object): try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, + self.simple_upsert_txn, table, keyvalues, values, @@ -669,7 +669,7 @@ class SQLBaseStore(object): "IntegrityError when upserting into %s; retrying: %s", table, e ) - def _simple_upsert_txn( + def simple_upsert_txn( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -693,11 +693,11 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_txn_native_upsert( + return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) else: - return self._simple_upsert_txn_emulated( + return self.simple_upsert_txn_emulated( txn, table, keyvalues, @@ -706,7 +706,7 @@ class SQLBaseStore(object): lock=lock, ) - def _simple_upsert_txn_emulated( + def simple_upsert_txn_emulated( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -775,7 +775,7 @@ class SQLBaseStore(object): # successfully inserted return True - def _simple_upsert_txn_native_upsert( + def simple_upsert_txn_native_upsert( self, txn, table, keyvalues, values, insertion_values={} ): """ @@ -809,7 +809,7 @@ class SQLBaseStore(object): ) txn.execute(sql, list(allvalues.values())) - def _simple_upsert_many_txn( + def simple_upsert_many_txn( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -829,15 +829,15 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_many_txn_native_upsert( + return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) else: - return self._simple_upsert_many_txn_emulated( + return self.simple_upsert_many_txn_emulated( txn, table, key_names, key_values, value_names, value_values ) - def _simple_upsert_many_txn_emulated( + def simple_upsert_many_txn_emulated( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -862,9 +862,9 @@ class SQLBaseStore(object): _keys = {x: y for x, y in zip(key_names, keyv)} _vals = {x: y for x, y in zip(value_names, valv)} - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) - def _simple_upsert_many_txn_native_upsert( + def simple_upsert_many_txn_native_upsert( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -909,8 +909,8 @@ class SQLBaseStore(object): return txn.execute_batch(sql, args) - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" ): """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -924,16 +924,16 @@ class SQLBaseStore(object): statement returns no rows """ return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def _simple_select_one_onecol( + def simple_select_one_onecol( self, table, keyvalues, retcol, allow_none=False, - desc="_simple_select_one_onecol", + desc="simple_select_one_onecol", ): """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -945,7 +945,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_one_onecol_txn, + self.simple_select_one_onecol_txn, table, keyvalues, retcol, @@ -953,10 +953,10 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_one_onecol_txn( + def simple_select_one_onecol_txn( cls, txn, table, keyvalues, retcol, allow_none=False ): - ret = cls._simple_select_onecol_txn( + ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -969,7 +969,7 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + def simple_select_onecol_txn(txn, table, keyvalues, retcol): sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: @@ -980,8 +980,8 @@ class SQLBaseStore(object): return [r[0] for r in txn] - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" ): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -995,12 +995,10 @@ class SQLBaseStore(object): Deferred: Results in a list """ return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol + desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1014,11 +1012,11 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols + desc, self.simple_select_list_txn, table, keyvalues, retcols ) @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1044,14 +1042,14 @@ class SQLBaseStore(object): return cls.cursor_to_dict(txn) @defer.inlineCallbacks - def _simple_select_many_batch( + def simple_select_many_batch( self, table, column, iterable, retcols, keyvalues={}, - desc="_simple_select_many_batch", + desc="simple_select_many_batch", batch_size=100, ): """Executes a SELECT query on the named table, which may return zero or @@ -1080,7 +1078,7 @@ class SQLBaseStore(object): for chunk in chunks: rows = yield self.runInteraction( desc, - self._simple_select_many_txn, + self.simple_select_many_txn, table, column, chunk, @@ -1093,7 +1091,7 @@ class SQLBaseStore(object): return results @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1126,13 +1124,13 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def _simple_update(self, table, keyvalues, updatevalues, desc): + def simple_update(self, table, keyvalues, updatevalues, desc): return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues + desc, self.simple_update_txn, table, keyvalues, updatevalues ) @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): + def simple_update_txn(txn, table, keyvalues, updatevalues): if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) else: @@ -1148,8 +1146,8 @@ class SQLBaseStore(object): return txn.rowcount - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" ): """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1169,12 +1167,12 @@ class SQLBaseStore(object): the update column in the 'keyvalues' dict as well. """ return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues + desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) if rowcount == 0: raise StoreError(404, "No row found (%s)" % (table,)) @@ -1182,7 +1180,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, @@ -1201,7 +1199,7 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1209,10 +1207,10 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): + def simple_delete_one_txn(txn, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1231,11 +1229,11 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod - def _simple_delete_txn(txn, table, keyvalues): + def simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1244,13 +1242,13 @@ class SQLBaseStore(object): txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + def simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): """Executes a DELETE query on the named table. Filters rows by if value of `column` is in `iterable`. @@ -1283,7 +1281,7 @@ class SQLBaseStore(object): return txn.rowcount - def _get_cache_dict( + def get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 ): # Fetch a mapping of room_id -> max stream position for "recent" rooms. @@ -1349,7 +1347,7 @@ class SQLBaseStore(object): # which is fine. pass - def _simple_select_list_paginate( + def simple_select_list_paginate( self, table, keyvalues, @@ -1358,7 +1356,7 @@ class SQLBaseStore(object): limit, retcols, order_direction="ASC", - desc="_simple_select_list_paginate", + desc="simple_select_list_paginate", ): """ Executes a SELECT query on the named table with start and limit, @@ -1380,7 +1378,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table, keyvalues, orderby, @@ -1391,7 +1389,7 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_list_paginate_txn( + def simple_select_list_paginate_txn( cls, txn, table, @@ -1452,9 +1450,7 @@ class SQLBaseStore(object): txn.execute(sql_count) return txn.fetchone()[0] - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1469,11 +1465,11 @@ class SQLBaseStore(object): """ return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols + desc, self.simple_search_list_txn, table, term, col, retcols ) @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): + def simple_search_list_txn(cls, txn, table, term, col, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 37d469ffd7..06955a0537 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self._simple_select_onecol( + updates = yield self.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore): if update_name in self._background_update_queue: return False - update_exists = await self._simple_select_one_onecol( + update_exists = await self.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore): no more work to do. """ if not self._background_update_queue: - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self._simple_select_one_onecol( + progress_json = yield self.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [] progress_json = json.dumps(progress) - return self._simple_insert( + return self.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self._simple_delete_one( + return self.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 474924c68f..2a5b33dda1 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -173,7 +173,7 @@ class DataStore( self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self._get_cache_dict( + presence_cache_prefill, min_presence_val = self.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -187,7 +187,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -202,7 +202,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -228,7 +228,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -242,7 +242,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -482,7 +482,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_select_list( + return self.simple_select_list( table="users", keyvalues={}, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], @@ -504,7 +504,7 @@ class DataStore( """ users = yield self.runInteraction( "get_users_paginate", - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table="users", keyvalues={"is_guest": False}, orderby=order, @@ -526,7 +526,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_search_list( + return self.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 22093484ed..b0d22faf3f 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self._simple_select_one_onecol( + result = yield self.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( + content_json = self.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -300,9 +300,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will + # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self._simple_upsert( + yield self.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -346,9 +346,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if + # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self._simple_upsert( + yield self.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 81babf2029..6b82fd392a 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self._simple_select_list( + results = yield self.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self._simple_select_one( + result = yield self.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self._simple_upsert( + return self.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,7 +257,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self._simple_delete_txn( + self.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 258c08722a..de3256049d 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="cache_invalidation_stream", values={ diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index cae93b0e22..66522a04b7 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self._simple_select_list( + res = yield self.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index a23744f11c..206d39134d 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( + already_inserted = self.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self._simple_insert_txn( + self.simple_insert_txn( txn, table="device_federation_inbox", values={ diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index a3ad23e783..727c582121 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + return self.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self._simple_select_list( + devices = yield self.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore): from_user_id, stream_id, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( + content = yield self.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( + devices = yield self.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self._execute( + rows = yield self.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self._execute( + return self.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self._simple_insert( + inserted = yield self.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self._simple_select_one_onecol( + hidden = yield self.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_one( + yield self.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_many( + yield self.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self._simple_update_one( + return self.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self._simple_delete( + yield self.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index 297966d9f4..d332f8a409 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self._simple_select_one_onecol( + room_id = yield self.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self._simple_select_onecol( + servers = yield self.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( + return self.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self._simple_insert_txn( + self.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="room_alias_servers", values=[ diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 113224fd7c..df89eda337 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_update_one( + yield self.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self._simple_insert_many( + yield self.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version (str): the version ID of the backup we're querying about """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self._simple_delete( + yield self.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, @@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self._simple_update( + return self.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -420,13 +420,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self._simple_update_one_txn( + return self.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 643327b57b..08bcdc4725 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): key_id) to json string for key """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -350,7 +350,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self._execute( + return self.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -367,7 +367,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self._simple_select_one_onecol_txn( + old_key_json = self.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -383,7 +383,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -442,12 +442,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -492,7 +492,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self._simple_insert_txn( + self.simple_insert_txn( txn, "devices", values={ @@ -505,7 +505,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self._simple_insert_txn( + self.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -539,7 +539,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self._simple_insert_many( + return self.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 90bef0cd2c..051ac7a8cb 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( + return self.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -235,7 +235,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( + return self.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -271,7 +271,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( + min_depth = self.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self._simple_select_one_onecol_txn( + depth = self.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore): if min_depth and depth >= min_depth: return - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore): For the given event, update the event edges table and forward and backward extremities tables. """ - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_edges", values=[ diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 04ce21ac66..0a37847cfd 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -441,7 +441,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead + # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. sql = """ @@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self._simple_delete( + res = yield self.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( + user_ids = self.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -844,7 +844,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +880,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +912,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 79c91fe284..98ae69e996 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -432,7 +432,7 @@ class EventsStore( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -580,12 +580,12 @@ class EventsStore( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -598,7 +598,7 @@ class EventsStore( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -722,7 +722,7 @@ class EventsStore( # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self._simple_insert_txn( + self.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -794,7 +794,7 @@ class EventsStore( d.pop("redacted_because", None) return d - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_json", values=[ @@ -811,7 +811,7 @@ class EventsStore( ], ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="events", values=[ @@ -841,7 +841,7 @@ class EventsStore( # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self._simple_update_txn( + self.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -983,7 +983,7 @@ class EventsStore( state_values.append(vals) - self._simple_insert_many_txn(txn, table="state_events", values=state_values) + self.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1032,7 +1032,7 @@ class EventsStore( # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="redactions", values={ @@ -1077,9 +1077,7 @@ class EventsStore( LIMIT ? """ - rows = yield self._execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 - ) + rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100) updates = [] @@ -1111,7 +1109,7 @@ class EventsStore( if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, @@ -1129,7 +1127,7 @@ class EventsStore( event_id (str): The ID of the event to censor. pruned_json (str): The pruned JSON """ - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="event_json", keyvalues={"event_id": event_id}, @@ -1780,7 +1778,7 @@ class EventsStore( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1807,15 +1805,15 @@ class EventsStore( curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1852,7 +1850,7 @@ class EventsStore( state group. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1882,7 +1880,7 @@ class EventsStore( # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1893,7 +1891,7 @@ class EventsStore( # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1904,7 +1902,7 @@ class EventsStore( # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1921,7 +1919,7 @@ class EventsStore( @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( + res = yield self.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1962,7 +1960,7 @@ class EventsStore( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self._simple_insert_many_txn( + return self.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1984,7 +1982,7 @@ class EventsStore( event_id (str): The event ID the expiry timestamp is associated with. expiry_ts (int): The timestamp at which to expire (delete) the event. """ - return self._simple_insert_txn( + return self.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, @@ -2043,7 +2041,7 @@ class EventsStore( txn (LoggingTransaction): The transaction to use to perform the deletion. event_id (str): The event ID to delete the associated expiry timestamp of. """ - return self._simple_delete_txn( + return self.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index aa87f9abc5..37dfc8c871 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self._simple_select_many_txn( + ev_rows = self.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self._simple_delete_many_txn( + deleted = self.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -533,7 +533,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): try: event_json = json.loads(event_json_raw) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn=txn, table="event_labels", values=[ diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e782e8f481..ec4af29299 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -78,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events ) - row_dict = self._new_transaction( + row_dict = self.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -745,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index f05ace299a..17ef7b9354 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self._simple_select_one_onecol( + def_json = yield self.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 5ded539af8..9e1d12bcb7 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore): * "invite" * "open" """ - return self._simple_update_one( + return self.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore): ) def get_group(self, group_id): - return self._simple_select_one( + return self.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self._simple_select_onecol( + return self.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -180,7 +180,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self._simple_select_one_onecol_txn( + room_in_group = self.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +193,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +204,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +224,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self._simple_select_one_txn( + existing = self.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +257,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +271,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +287,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self._simple_delete( + return self.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +299,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +316,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( + category = yield self.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +343,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +352,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self._simple_delete( + return self.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +360,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +377,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( + role = yield self.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +404,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,7 +413,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self._simple_delete( + return self.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", @@ -444,7 +444,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self._simple_select_one_onecol_txn( + user_in_group = self.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +457,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +468,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +488,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self._simple_select_one_txn( + existing = self.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +517,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +531,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +547,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self._simple_delete( + return self.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +561,7 @@ class GroupServerStore(SQLBaseStore): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -630,7 +630,7 @@ class GroupServerStore(SQLBaseStore): ) def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +639,7 @@ class GroupServerStore(SQLBaseStore): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +650,7 @@ class GroupServerStore(SQLBaseStore): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self._simple_insert( + return self.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +659,7 @@ class GroupServerStore(SQLBaseStore): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +682,7 @@ class GroupServerStore(SQLBaseStore): """ def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +697,7 @@ class GroupServerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self._simple_select_one_onecol_txn( + row = self.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -738,7 +738,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +749,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +766,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -781,27 +781,27 @@ class GroupServerStore(SQLBaseStore): def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -812,14 +812,14 @@ class GroupServerStore(SQLBaseStore): ) def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + return self.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( + return self.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,13 +828,13 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -847,7 +847,7 @@ class GroupServerStore(SQLBaseStore): def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +857,7 @@ class GroupServerStore(SQLBaseStore): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self._simple_update_one( + return self.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +893,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self._simple_delete_txn( + self.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +911,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +930,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +940,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +951,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -976,7 +976,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self._simple_insert( + yield self.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +991,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self._simple_update_one( + yield self.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1017,7 +1017,7 @@ class GroupServerStore(SQLBaseStore): def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self._simple_update_one( + return self.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1027,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self._simple_update_one( + return self.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1046,7 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self._simple_delete( + return self.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1057,7 @@ class GroupServerStore(SQLBaseStore): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self._simple_select_one( + row = yield self.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1072,7 @@ class GroupServerStore(SQLBaseStore): return None def get_joined_groups(self, user_id): - return self._simple_select_onecol( + return self.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1188,7 +1188,7 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self._simple_delete_txn( + self.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ebc7db3ed6..c7150432b3 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -129,7 +129,7 @@ class KeyStore(SQLBaseStore): return self.runInteraction( "store_server_verify_keys", - self._simple_upsert_many_txn, + self.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore): ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self._simple_upsert( + return self.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 0f2887bdce..0cb9446f96 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): Returns: None if the media_id doesn't exist. """ - return self._simple_select_one( + return self.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id, url_cache=None, ): - return self._simple_insert( + return self.simple_insert( "local_media_repository", { "media_id": media_id, @@ -129,7 +129,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self._simple_insert( + return self.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( + return self.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( + return self.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self._simple_insert( + return self.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -253,7 +253,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return self.runInteraction("update_cached_last_access_time", update_cache_txn) def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( + return self.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +278,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,18 +300,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self._execute( + return self.execute( "get_remote_media_before", self.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b41c3d317a..b8fc28f97b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self._new_transaction( + self.new_transaction( dbconn, "initialise_mau_threepids", [], @@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( + is_insert = self.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 79b40044d9..650e49750e 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( + return self.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 523ed6575e..a5e121efd1 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( + return self.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( + return self.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index e4e8a1c1d6..c8b5b60301 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self._simple_select_one( + profile = yield self.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self._simple_insert( + return self.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( + return self.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( + return self.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self._simple_upsert( + return self.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( + return self.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self._simple_delete( + yield self.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b520062d84..75bd499bcd 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -75,7 +75,7 @@ class PushRulesWorkerStore( def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) - push_rules_prefill, push_rules_id = self._get_cache_dict( + push_rules_prefill, push_rules_id = self.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +100,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +124,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( + results = yield self.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -162,7 +162,7 @@ class PushRulesWorkerStore( results = {user_id: [] for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +320,7 @@ class PushRulesWorkerStore( results = {user_id: {} for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self._simple_select_one_txn( + res = self.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -499,7 +499,7 @@ class PushRuleStore(PushRulesWorkerStore): actions_json, update_stream=True, ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id + """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self._simple_insert_txn( + self.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self._simple_insert_txn(txn, "push_rules_stream", values=values) + self.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d76861cdc0..d5a169872b 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( + ret = yield self.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( + ret = yield self.simple_select_list( "pushers", keyvalues, [ @@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -229,8 +229,8 @@ class PusherStore(PusherWorkerStore): ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self._simple_insert_txn( + self.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -296,7 +296,7 @@ class PusherStore(PusherWorkerStore): def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self._simple_update_one( + yield self.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore): Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self._simple_update( + updated = yield self.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( + yield self.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( + res = yield self.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -361,8 +361,8 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( + # (pusher, room_id) so simple_upsert will retry + yield self.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 8b17334ff4..380f388e30 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( + return self.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -335,7 +335,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self._simple_select_one_txn( + res = self.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +388,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore): }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -514,7 +514,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +523,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 653c9318cb..debc6706f5 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -45,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore): @cached() def get_user_by_id(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self._simple_update_txn( + self.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self._simple_update_one( + yield self.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self._simple_update_one( + yield self.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self._simple_delete_one( + yield self.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self._simple_update_one( + return self.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -351,7 +351,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +361,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -394,7 +394,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self._simple_select_one_onecol( + return await self.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -536,7 +536,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self._simple_select_one_txn( + ret = self.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -549,7 +549,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( + yield self.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -557,7 +557,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( + ret = yield self.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -566,7 +566,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( + return self.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -579,7 +579,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id: The user id to delete all threepids of """ - return self._simple_delete( + return self.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -601,7 +601,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self._simple_upsert( + return self.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -627,7 +627,7 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self._simple_select_list( + return self.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -648,7 +648,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self._simple_delete( + return self.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -671,7 +671,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -689,7 +689,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -776,12 +776,12 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -961,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ next_id = self._access_tokens_id_gen.get_next() - yield self._simple_insert( + yield self.simple_insert( "access_tokens", { "id": next_id, @@ -1037,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self._simple_select_one_txn( + self.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1045,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1059,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) else: - self._simple_insert_txn( + self.simple_insert_txn( txn, "users", values={ @@ -1114,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._simple_insert( + return self.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1132,7 +1132,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1152,7 +1152,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1176,7 +1176,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1234,7 +1234,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def delete_access_token(self, access_token): def f(txn): - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1246,7 +1246,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1261,7 +1261,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self._simple_insert( + return self.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1274,7 +1274,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( + return self.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1285,7 +1285,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1315,7 +1315,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1333,7 +1333,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): 400, "This client_secret does not match the provided session_id" ) - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1358,7 +1358,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self._simple_update_txn( + self.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1401,7 +1401,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): if validated_at: insertion_values["validated_at"] = validated_at - return self._simple_upsert( + return self.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1439,7 +1439,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1452,7 +1452,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self._simple_insert_txn( + self.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1501,7 +1501,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( + self.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1560,7 +1560,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expiration_ts, ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "account_validity", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 7d5de0ea2e..f81f9279a1 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( + self.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore): ) def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 858f65582b..aa5e10538b 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore): aggregation_key = relation.get("key") - self._simple_insert_txn( + self.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore): redacted_event_id (str): The event that was redacted. """ - self._simple_delete_txn( + self.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index b7f9024811..8f9b6365c1 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -53,7 +53,7 @@ class RoomWorkerStore(SQLBaseStore): Returns: A dict containing the room information, or None if the room is unknown. """ - return self._simple_select_one( + return self.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -62,7 +62,7 @@ class RoomWorkerStore(SQLBaseStore): ) def get_public_room_ids(self): - return self._simple_select_onecol( + return self.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -266,7 +266,7 @@ class RoomWorkerStore(SQLBaseStore): @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -287,7 +287,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self._simple_select_one( + row = yield self.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -407,7 +407,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ev = json.loads(row["json"]) retention_policy = json.dumps(ev["content"]) - self._simple_insert_txn( + self.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -453,7 +453,7 @@ class RoomStore(RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self._simple_insert_txn( + self.simple_insert_txn( txn, "rooms", { @@ -463,7 +463,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) if is_public: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -482,14 +482,14 @@ class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self._simple_select_list_txn( + entries = self.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -507,7 +507,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -547,7 +547,7 @@ class RoomStore(RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -560,7 +560,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self._simple_delete_txn( + self.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -570,7 +570,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) - entries = self._simple_select_list_txn( + entries = self.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -588,7 +588,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -652,7 +652,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # Ignore the event if one of the value isn't an integer. return - self._simple_insert_txn( + self.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -671,7 +671,7 @@ class RoomStore(RoomWorkerStore, SearchStore): self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( + return self.simple_insert( table="event_reports", values={ "id": next_id, @@ -717,7 +717,7 @@ class RoomStore(RoomWorkerStore, SearchStore): Returns: Deferred """ - yield self._simple_upsert( + yield self.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index b314d75941..fe2428a281 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self._simple_select_one_txn( + pending_update = self.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -603,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -643,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -683,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -805,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self._simple_select_onecol( + room_ids = yield self.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -820,7 +820,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Get user_id and membership of a set of event IDs. """ - return self._simple_select_many_batch( + return self.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -990,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1028,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_invites", values={ diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index d1d7c6863d..f735cf095c 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -441,7 +441,7 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -455,7 +455,7 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( + count_results = yield self.execute( "search_rooms_count", self.cursor_to_dict, count_sql, *count_args ) @@ -586,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -600,7 +600,7 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( + count_results = yield self.execute( "search_rooms_count", self.cursor_to_dict, count_sql, *count_args ) diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 556191b76f..f3da29ce14 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore): } ) - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 6a90daea31..2b33ec1a35 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self._simple_select_one_onecol_txn( + next_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self._simple_select_one_onecol_txn( + next_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -431,7 +431,7 @@ class StateGroupWorkerStore( """ def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( + prev_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +442,7 @@ class StateGroupWorkerStore( if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self._simple_select_list_txn( + delta_ids = self.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -644,7 +644,7 @@ class StateGroupWorkerStore( @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +661,7 @@ class StateGroupWorkerStore( def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +902,7 @@ class StateGroupWorkerStore( state_group = self.database_engine.get_next_state_group_id(txn) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +911,7 @@ class StateGroupWorkerStore( # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self._simple_select_one_onecol_txn( + is_in_db = self.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +926,13 @@ class StateGroupWorkerStore( potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +947,7 @@ class StateGroupWorkerStore( ], ) else: - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1007,7 +1007,7 @@ class StateGroupWorkerStore( referenced. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1065,7 +1065,7 @@ class StateBackgroundUpdateStore( batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self._execute( + rows = yield self.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1135,13 @@ class StateBackgroundUpdateStore( if prev_state.get(key, None) != value } - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1150,13 @@ class StateBackgroundUpdateStore( }, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1263,7 +1263,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): state_groups[event.event_id] = context.state_group - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 28f33ec18f..03b908026b 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -105,7 +105,7 @@ class StateDeltasStore(SQLBaseStore): ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( + return self.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 45b3de7d56..3aeba859fd 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self._simple_upsert( + return self.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self._simple_select_list_paginate_txn( + slice_list = self.simple_select_list_paginate_txn( txn, table + "_historical", {id_col: stats_id}, @@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self._simple_select_one( + return self.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,7 +344,7 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, @@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( + current_row = self.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) + self.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) + self.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( + src_row = self.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( + dest_current_row = self.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self._simple_insert_txn(txn, into_table, merged_dict) + self.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="current_state_events", column="type", diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 21a410afd0..60487c4559 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self._get_cache_dict( + event_cache_prefill, min_event_val = self.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -576,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -589,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self._simple_select_one( + return self.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -613,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self._execute( + return self.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -709,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self._simple_select_one_txn( + results = self.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -797,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -805,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return self._simple_update_one( + return self.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index aa24339717..85012403be 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self._simple_select_list( + deferred = self.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -153,7 +153,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self._simple_select_list( + return self.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -178,7 +178,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 01b1be5e14..c162f3ea16 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self._simple_insert( + return self.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self._simple_select_one_txn( + prev_row = self.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 652abe0e6a..1a85aabbfb 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() yield self.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) yield self._end_background_update("populate_user_directory_createtables") return 1 @@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ Update the user directory stream position, then clean up the old tables. """ - position = yield self._simple_select_one_onecol( + position = yield self.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -243,7 +243,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore to_insert.clear() # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( @@ -312,7 +312,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( @@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( + new_entry = self.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( + self.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore def _add_users_in_public_rooms_txn(txn): - self._simple_upsert_many_txn( + self.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -519,7 +519,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore @cached() def get_user_in_directory(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( + return self.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -547,21 +547,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, @@ -575,14 +575,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self._simple_select_onecol( + user_ids_share_pub = yield self.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self._simple_select_onecol( + user_ids_share_priv = yield self.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,17 +605,17 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, @@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self._simple_select_onecol( + rows = yield self.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self._simple_select_onecol( + pub_rows = yield self.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self._execute( + rows = yield self.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,9 +786,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args - ) + results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index aa4f0da5f0..37860af070 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e0075ccd32..380fd0d107 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -45,13 +45,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -180,7 +180,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.handler.stats_enabled = True self.store._all_done = False self.get_success( - self.store._simple_update_one( + self.store.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,7 +188,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) @@ -205,13 +205,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -656,12 +656,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_delete( + self.store.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -675,7 +675,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +685,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +695,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41..d5b1c5b4ac 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -184,7 +184,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9575058252..124ce0768a 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 9b81b536f5..7b7434a468 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -356,7 +356,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.get_success( self.storage.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -383,7 +383,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.get_success( self.storage.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c..de5e4a5fce 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -65,7 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +77,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +106,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +122,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +137,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +150,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +165,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index afac5dec7f..25bdd2c163 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -218,7 +218,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store._simple_update( + self.store.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": "device_id"}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -245,7 +245,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -297,7 +297,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -323,7 +323,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d..2337a1ae46 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4561c3e383..4930b6777e 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 105a0c2b02..d389cf578f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -132,7 +132,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", diff --git a/tests/unittest.py b/tests/unittest.py index 31997a0f31..295573bc46 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore()._simple_insert( + self.hs.get_datastore().simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", -- cgit 1.5.1 From 756d4942f5707922f29fe1fdfd945d73a19d7ac3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 13:52:46 +0000 Subject: Move DB pool and helper functions into dedicated Database class --- scripts/synapse_port_db | 73 +- synapse/app/_base.py | 2 +- synapse/app/user_dir.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/storage/_base.py | 1468 +------------------ synapse/storage/background_updates.py | 16 +- synapse/storage/data_stores/main/__init__.py | 36 +- synapse/storage/data_stores/main/account_data.py | 26 +- synapse/storage/data_stores/main/appservice.py | 24 +- synapse/storage/data_stores/main/cache.py | 6 +- synapse/storage/data_stores/main/client_ips.py | 24 +- synapse/storage/data_stores/main/deviceinbox.py | 20 +- synapse/storage/data_stores/main/devices.py | 70 +- synapse/storage/data_stores/main/directory.py | 20 +- synapse/storage/data_stores/main/e2e_room_keys.py | 28 +- .../storage/data_stores/main/end_to_end_keys.py | 44 +- .../storage/data_stores/main/event_federation.py | 40 +- .../storage/data_stores/main/event_push_actions.py | 42 +- synapse/storage/data_stores/main/events.py | 90 +- .../storage/data_stores/main/events_bg_updates.py | 24 +- synapse/storage/data_stores/main/events_worker.py | 18 +- synapse/storage/data_stores/main/filtering.py | 4 +- synapse/storage/data_stores/main/group_server.py | 160 +-- synapse/storage/data_stores/main/keys.py | 12 +- .../storage/data_stores/main/media_repository.py | 44 +- .../data_stores/main/monthly_active_users.py | 12 +- synapse/storage/data_stores/main/openid.py | 6 +- synapse/storage/data_stores/main/presence.py | 12 +- synapse/storage/data_stores/main/profile.py | 28 +- synapse/storage/data_stores/main/push_rule.py | 36 +- synapse/storage/data_stores/main/pusher.py | 34 +- synapse/storage/data_stores/main/receipts.py | 36 +- synapse/storage/data_stores/main/registration.py | 164 +-- synapse/storage/data_stores/main/rejections.py | 4 +- synapse/storage/data_stores/main/relations.py | 12 +- synapse/storage/data_stores/main/room.py | 74 +- synapse/storage/data_stores/main/roommember.py | 44 +- synapse/storage/data_stores/main/search.py | 30 +- synapse/storage/data_stores/main/signatures.py | 4 +- synapse/storage/data_stores/main/state.py | 54 +- synapse/storage/data_stores/main/state_deltas.py | 8 +- synapse/storage/data_stores/main/stats.py | 48 +- synapse/storage/data_stores/main/stream.py | 30 +- synapse/storage/data_stores/main/tags.py | 18 +- synapse/storage/data_stores/main/transactions.py | 22 +- synapse/storage/data_stores/main/user_directory.py | 80 +- .../storage/data_stores/main/user_erasure_store.py | 6 +- synapse/storage/database.py | 1485 ++++++++++++++++++++ tests/handlers/test_stats.py | 30 +- tests/handlers/test_user_directory.py | 12 +- tests/rest/admin/test_admin.py | 2 +- tests/storage/test__base.py | 16 +- tests/storage/test_background_update.py | 2 +- tests/storage/test_base.py | 18 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_client_ips.py | 12 +- tests/storage/test_event_federation.py | 8 +- tests/storage/test_event_push_actions.py | 12 +- tests/storage/test_monthly_active_users.py | 6 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/unittest.py | 2 +- 62 files changed, 2377 insertions(+), 2295 deletions(-) create mode 100644 synapse/storage/database.py (limited to 'tests') diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index c4cf11d19a..7a2e177d3d 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -173,14 +173,14 @@ class Store( return (yield self.db_pool.runWithConnection(r)) def execute(self, f, *args, **kwargs): - return self.runInteraction(f.__name__, f, *args, **kwargs) + return self.db.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.runInteraction("execute_sql", r) + return self.db.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -223,7 +223,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store.simple_select_one( + row = yield self.postgres_store.db.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -233,12 +233,14 @@ class Porter(object): total_to_port = None if row is None: if table == "sent_transactions": - forward_chunk, already_ported, total_to_port = ( - yield self._setup_sent_transactions() - ) + ( + forward_chunk, + already_ported, + total_to_port, + ) = yield self._setup_sent_transactions() backward_chunk = 0 else: - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -268,7 +270,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -322,7 +324,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -363,7 +365,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) + headers, frows, brows = yield self.sqlite_store.db.runInteraction("select", r) if frows or brows: if frows: @@ -377,7 +379,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store.simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -416,7 +418,7 @@ class Porter(object): return headers, rows - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -433,8 +435,8 @@ class Porter(object): rows_dict = [] for row in rows: d = dict(zip(headers, row)) - if "\0" in d['value']: - logger.warning('dropping search row %s', d) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) else: rows_dict.append(d) @@ -454,7 +456,7 @@ class Porter(object): ], ) - self.postgres_store.simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -504,17 +506,14 @@ class Porter(object): self.progress.set_state("Preparing %s" % config["name"]) conn = self.setup_db(config, engine) - db_pool = adbapi.ConnectionPool( - config["name"], **config["args"] - ) + db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) hs = MockHomeserver(self.hs_config, engine, conn, db_pool) store = Store(conn, hs) - yield store.runInteraction( - "%s_engine.check_database" % config["name"], - engine.check_database, + yield store.db.runInteraction( + "%s_engine.check_database" % config["name"], engine.check_database, ) return store @@ -541,7 +540,9 @@ class Porter(object): self.sqlite_store = yield self.build_db_store(self.sqlite_config) # Check if all background updates are done, abort if not. - updates_complete = yield self.sqlite_store.has_completed_background_updates() + updates_complete = ( + yield self.sqlite_store.has_completed_background_updates() + ) if not updates_complete: sys.stderr.write( "Pending background updates exist in the SQLite3 database." @@ -582,22 +583,22 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction("alter_table", alter_table) + yield self.postgres_store.db.runInteraction("alter_table", alter_table) except Exception: # On Error Resume Next pass - yield self.postgres_store.runInteraction( + yield self.postgres_store.db.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store.simple_select_onecol( + sqlite_tables = yield self.sqlite_store.db.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store.simple_select_onecol( + postgres_tables = yield self.postgres_store.db.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -687,11 +688,11 @@ class Porter(object): rows = txn.fetchall() headers = [column[0] for column in txn.description] - ts_ind = headers.index('ts') + ts_ind = headers.index("ts") return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -724,7 +725,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -737,7 +738,7 @@ class Porter(object): txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - size, = txn.fetchone() + (size,) = txn.fetchone() return int(size) remaining_count = yield self.sqlite_store.execute(get_sent_table_size) @@ -790,7 +791,7 @@ class Porter(object): next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) ############################################## @@ -871,7 +872,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds) + duration_str = "%02dm %02ds" % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -881,7 +882,7 @@ class CursesProgress(Progress): left = float(self.total_remaining) / self.total_processed est_remaining = (int(now) - self.start_time) * left - est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" status = "Time spent: %s (est. remaining: %s)" % ( @@ -967,7 +968,7 @@ if __name__ == "__main__": description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--sqlite-database", required=True, @@ -976,12 +977,12 @@ if __name__ == "__main__": ) parser.add_argument( "--postgres-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', help="display a curses based progress UI" + "--curses", action="store_true", help="display a curses based progress UI" ) parser.add_argument( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2ac7d5c064..9c96816096 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -269,7 +269,7 @@ def start(hs, listeners=None): # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().start_profiling() + hs.get_datastore().db.start_profiling() setup_sentry(hs) setup_sdnotify(hs) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 0fa2b50999..b6d4481725 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -64,7 +64,7 @@ class UserDirectorySlaveStore( super(UserDirectorySlaveStore, self).__init__(db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 735b882363..305b9b0178 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -175,4 +175,4 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.runInteraction(desc, func, *args, **kwargs) + return self._store.db.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 9205e550bb..fd5bb3e1de 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,1304 +16,28 @@ # limitations under the License. import logging import random -import sys -import time -from typing import Iterable, Tuple -from six import PY2, iteritems, iterkeys, itervalues -from six.moves import builtins, intern, range +from six import PY2 +from six.moves import builtins from canonicaljson import json -from prometheus_client import Histogram -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.database import LoggingTransaction # noqa: F401 +from synapse.storage.database import make_in_list_sql_clause # noqa: F401 +from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util.stringutils import exception_to_unicode - -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 - -sql_logger = logging.getLogger("synapse.storage.SQL") -transaction_logger = logging.getLogger("synapse.storage.txn") -perf_logger = logging.getLogger("synapse.storage.TIME") - -sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") - -sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) -sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) - - -# Unique indexes which have been added in background updates. Maps from table name -# to the name of the background update which added the unique index to that table. -# -# This is used by the upsert logic to figure out which tables are safe to do a proper -# UPSERT on: until the relevant background update has completed, we -# have to emulate an upsert by locking the table. -# -UNIQUE_INDEX_BACKGROUND_UPDATES = { - "user_ips": "user_ips_device_unique_index", - "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", - "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", - "event_search": "event_search_event_id_idx", -} - - -class LoggingTransaction(object): - """An object that almost-transparently proxies for the 'txn' object - passed to the constructor. Adds logging and metrics to the .execute() - method. - - Args: - txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to - that have been added by `call_after` which should be run on - successful completion of the transaction. None indicates that no - callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended - to that have been added by `call_on_exception` which should be run - if transaction ends with an error. None indicates that no callbacks - should be allowed to be scheduled to run. - """ - - __slots__ = [ - "txn", - "name", - "database_engine", - "after_callbacks", - "exception_callbacks", - ] - - def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None - ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) - - def call_after(self, callback, *args, **kwargs): - """Call the given callback on the main twisted thread after the - transaction has finished. Used to invalidate the caches on the - correct thread. - """ - self.after_callbacks.append((callback, args, kwargs)) - - def call_on_exception(self, callback, *args, **kwargs): - self.exception_callbacks.append((callback, args, kwargs)) - - def __getattr__(self, name): - return getattr(self.txn, name) - - def __setattr__(self, name, value): - setattr(self.txn, name, value) - - def __iter__(self): - return self.txn.__iter__() - - def execute_batch(self, sql, args): - if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch - - self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) - else: - for val in args: - self.execute(sql, val) - - def execute(self, sql, *args): - self._do_execute(self.txn.execute, sql, *args) - - def executemany(self, sql, *args): - self._do_execute(self.txn.executemany, sql, *args) - - def _make_sql_one_line(self, sql): - "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) - - def _do_execute(self, func, sql, *args): - sql = self._make_sql_one_line(sql) - - # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] {%s} %s", self.name, sql) - - sql = self.database_engine.convert_param_style(sql) - if args: - try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) - except Exception: - # Don't let logging failures stop SQL from working - pass - - start = time.time() - - try: - return func(sql, *args) - except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) - raise - finally: - secs = time.time() - start - sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) - - -class PerformanceCounters(object): - def __init__(self): - self.current_counters = {} - self.previous_counters = {} - - def update(self, key, duration_secs): - count, cum_time = self.current_counters.get(key, (0, 0)) - count += 1 - cum_time += duration_secs - self.current_counters[key] = (count, cum_time) - - def interval(self, interval_duration_secs, limit=3): - counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): - prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append( - ( - (cum_time - prev_time) / interval_duration_secs, - count - prev_count, - name, - ) - ) - - self.previous_counters = dict(self.current_counters) - - counters.sort(reverse=True) - - top_n_counters = ", ".join( - "%s(%d): %.3f%%" % (name, count, 100 * ratio) - for ratio, count, name in counters[:limit] - ) - - return top_n_counters - class SQLBaseStore(object): - _TXN_ID = 0 - def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() - - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 - - # TODO(paul): These can eventually be removed once the metrics code - # is running in mainline, and we have some nice monitoring frontends - # to watch it - self._txn_perf_counters = PerformanceCounters() - self.database_engine = hs.database_engine - - # A set of tables that are not safe to use native upserts in. - self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - - # We add the user_directory_search table to the blacklist on SQLite - # because the existing search table does not have an index, making it - # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): - self._unsafe_to_upsert_tables.add("user_directory_search") - - if self.database_engine.can_native_upsert: - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates of tables that aren't safe to update. - self._clock.call_later( - 0.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - + self.db = Database(hs) self.rand = random.SystemRandom() - @defer.inlineCallbacks - def _check_safe_to_upsert(self): - """ - Is it safe to use native UPSERT? - - If there are background updates, we will need to wait, as they may be - the addition of indexes that set the UNIQUE constraint that we require. - - If the background updates have not completed, wait 15 sec and check again. - """ - updates = yield self.simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", - ) - updates = [x["update_name"] for x in updates] - - for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: - logger.debug("Now safe to upsert in %s", table) - self._unsafe_to_upsert_tables.discard(table) - - # If there's any updates still running, reschedule to run. - if updates: - self._clock.call_later( - 15.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - - def start_profiling(self): - self._previous_loop_ts = monotonic_time() - - def loop(): - curr = self._current_txn_total_time - prev = self._previous_txn_total_time - self._previous_txn_total_time = curr - - time_now = monotonic_time() - time_then = self._previous_loop_ts - self._previous_loop_ts = time_now - - duration = time_now - time_then - ratio = (curr - prev) / duration - - top_three_counters = self._txn_perf_counters.interval(duration, limit=3) - - perf_logger.info( - "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters - ) - - self._clock.looping_call(loop, 10000) - - def new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): - start = monotonic_time() - txn_id = self._TXN_ID - - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - - name = "%s-%x" % (desc, txn_id) - - transaction_logger.debug("[TXN START] {%s}", name) - - try: - i = 0 - N = 5 - while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - try: - r = func(cursor, *args, **kwargs) - conn.commit() - return r - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) - continue - raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), - ) - continue - raise - finally: - # we're either about to retry with a new cursor, or we're about to - # release the connection. Once we release the connection, it could - # get used for another query, which might do a conn.rollback(). - # - # In the latter case, even though that probably wouldn't affect the - # results of this transaction, python's sqlite will reset all - # statements on the connection [1], which will make our cursor - # invalid [2]. - # - # In any case, continuing to read rows after commit()ing seems - # dubious from the PoV of ACID transactional semantics - # (sqlite explicitly says that once you commit, you may see rows - # from subsequent updates.) - # - # In psycopg2, cursors are essentially a client-side fabrication - - # all the data is transferred to the client side when the statement - # finishes executing - so in theory we could go on streaming results - # from the cursor, but attempting to do so would make us - # incompatible with sqlite, so let's make sure we're not doing that - # by closing the cursor. - # - # (*named* cursors in psycopg2 are different and are proper server- - # side things, but (a) we don't use them and (b) they are implicitly - # closed by ending the transaction anyway.) - # - # In short, if we haven't finished with the cursor yet, that's a - # problem waiting to bite us. - # - # TL;DR: we're done with the cursor, so we can close it. - # - # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 - # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 - cursor.close() - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = monotonic_time() - duration = end - start - - LoggingContext.current_context().add_database_transaction(duration) - - transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) - - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, duration) - sql_txn_timer.labels(desc).observe(duration) - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - """Starts a transaction on the database and runs a given function - - Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a - database transaction (twisted.enterprise.adbapi.Transaction) as - its first argument, followed by `args` and `kwargs`. - - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - after_callbacks = [] - exception_callbacks = [] - - if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warning("Starting db txn '%s' from sentinel context", desc) - - try: - result = yield self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs - ) - - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except: # noqa: E722, as we reraise the exception this is fine. - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise - - return result - - @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runWithConnection() method on the underlying db_pool. - - Arguments: - func (func): callback function, which will be called with a - database connection (twisted.enterprise.adbapi.Connection) as - its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - parent_context = LoggingContext.current_context() - if parent_context == LoggingContext.sentinel: - logger.warning( - "Starting db connection from sentinel context: metrics will be lost" - ) - parent_context = None - - start_time = monotonic_time() - - def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection", parent_context) as context: - sched_duration_sec = monotonic_time() - start_time - sql_scheduling_timer.observe(sched_duration_sec) - context.add_database_scheduled(sched_duration_sec) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - return func(conn, *args, **kwargs) - - result = yield make_deferred_yieldable( - self._db_pool.runWithConnection(inner_func, *args, **kwargs) - ) - - return result - - @staticmethod - def cursor_to_dict(cursor): - """Converts a SQL cursor into an list of dicts. - - Args: - cursor : The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) - return results - - def execute(self, desc, decoder, query, *args): - """Runs a single query for a result set. - - Args: - decoder - The function which can resolve the cursor results to - something meaningful. - query - The query string to execute - *args - Query args. - Returns: - The result of decoder(results) - """ - - def interaction(txn): - txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() - - return self.runInteraction(desc, interaction) - - # "Simple" SQL API methods that operate on a single table with no JOINs, - # no complex WHERE clauses, just a dict of values for columns. - - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): - """Executes an INSERT query on the named table. - - Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised - when a conflicting row already exists. If True, False will be - returned by the function instead - desc : string giving a description of the transaction - - Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True - """ - try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True - - @staticmethod - def simple_insert_txn(txn, table, values): - keys, vals = zip(*values.items()) - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys), - ", ".join("?" for _ in keys), - ) - - txn.execute(sql, vals) - - def simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self.simple_insert_many_txn, table, values) - - @staticmethod - def simple_insert_many_txn(txn, table, values): - if not values: - return - - # This is a *slight* abomination to get a list of tuples of key names - # and a list of tuples of value names. - # - # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] - # - # The sort is to ensure that we don't rely on dictionary iteration - # order. - keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] - ) - - for k in keys: - if k != keys[0]: - raise RuntimeError("All items must have the same keys") - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) - - txn.executemany(sql, vals) - - @defer.inlineCallbacks - def simple_upsert( - self, - table, - keyvalues, - values, - insertion_values={}, - desc="simple_upsert", - lock=True, - ): - """ - - `lock` should generally be set to True (the default), but can be set - to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - Deferred(None or bool): Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - attempts = 0 - while True: - try: - result = yield self.runInteraction( - desc, - self.simple_upsert_txn, - table, - keyvalues, - values, - insertion_values, - lock=lock, - ) - return result - except self.database_engine.module.IntegrityError as e: - attempts += 1 - if attempts >= 5: - # don't retry forever, because things other than races - # can cause IntegrityErrors - raise - - # presumably we raced with another transaction: let's retry. - logger.warning( - "IntegrityError when upserting into %s; retrying: %s", table, e - ) - - def simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Pick the UPSERT method which works best on the platform. Either the - native one (Pg9.5+, recent SQLites), or fall back to an emulated method. - - Args: - txn: The transaction to use. - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self.simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values - ) - else: - return self.simple_upsert_txn_emulated( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, - lock=lock, - ) - - def simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - bool: Return True if a new entry was created, False if an existing - one was updated. - """ - # We need to lock the table :(, unless we're *really* careful - if lock: - self.database_engine.lock_table(txn, table) - - def _getwhere(key): - # If the value we're passing in is None (aka NULL), we need to use - # IS, not =, as NULL = NULL equals NULL (False). - if keyvalues[key] is None: - return "%s IS ?" % (key,) - else: - return "%s = ?" % (key,) - - if not values: - # If `values` is empty, then all of the values we care about are in - # the unique key, so there is nothing to UPDATE. We can just do a - # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(keyvalues.values()) - txn.execute(sql, sqlargs) - if txn.fetchall(): - # We have an existing record. - return False - else: - # First try to update. - sql = "UPDATE %s SET %s WHERE %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(values.values()) + list(keyvalues.values()) - - txn.execute(sql, sqlargs) - if txn.rowcount > 0: - # successfully updated at least one row. - return False - - # We didn't find any existing rows, so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) - - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ) - txn.execute(sql, list(allvalues.values())) - # successfully inserted - return True - - def simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): - """ - Use the native UPSERT functionality in recent PostgreSQL versions. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None - """ - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(insertion_values) - - if not values: - latter = "NOTHING" - else: - allvalues.update(values) - latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - - sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ", ".join(k for k in keyvalues), - latter, - ) - txn.execute(sql, list(allvalues.values())) - - def simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self.simple_upsert_many_txn_native_upsert( - txn, table, key_names, key_values, value_names, value_values - ) - else: - return self.simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values - ) - - def simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, but without native UPSERT support or batching. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] - - for keyv, valv in zip(key_values, value_values): - _keys = {x: y for x, y in zip(key_names, keyv)} - _vals = {x: y for x, y in zip(value_names, valv)} - - self.simple_upsert_txn_emulated(txn, table, _keys, _vals) - - def simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, using batching where possible. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - allnames = [] - allnames.extend(key_names) - allnames.extend(value_names) - - if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. - latter = "NOTHING" - value_values = [() for x in range(len(key_values))] - else: - latter = "UPDATE SET " + ", ".join( - k + "=EXCLUDED." + k for k in value_names - ) - - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - - args = [] - - for x, y in zip(key_values, value_values): - args.append(tuple(x) + tuple(y)) - - return txn.execute_batch(sql, args) - - def simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning multiple columns from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows - """ - return self.runInteraction( - desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none - ) - - def simple_select_one_onecol( - self, - table, - keyvalues, - retcol, - allow_none=False, - desc="simple_select_one_onecol", - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return - """ - return self.runInteraction( - desc, - self.simple_select_one_onecol_txn, - table, - keyvalues, - retcol, - allow_none=allow_none, - ) - - @classmethod - def simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): - ret = cls.simple_select_onecol_txn( - txn, table=table, keyvalues=keyvalues, retcol=retcol - ) - - if ret: - return ret[0] - else: - if allow_none: - return None - else: - raise StoreError(404, "No row found") - - @staticmethod - def simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} - - if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - txn.execute(sql, list(keyvalues.values())) - else: - txn.execute(sql) - - return [r[0] for r in txn] - - def simple_select_onecol( - self, table, keyvalues, retcol, desc="simple_select_onecol" - ): - """Executes a SELECT query on the named table, which returns a list - comprising of the values of the named column from the selected rows. - - Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. - - Returns: - Deferred: Results in a list - """ - return self.runInteraction( - desc, self.simple_select_onecol_txn, table, keyvalues, retcol - ) - - def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, self.simple_select_list_txn, table, keyvalues, retcols - ) - - @classmethod - def simple_select_list_txn(cls, txn, table, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - """ - if keyvalues: - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - txn.execute(sql, list(keyvalues.values())) - else: - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - txn.execute(sql) - - return cls.cursor_to_dict(txn) - - @defer.inlineCallbacks - def simple_select_many_batch( - self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="simple_select_many_batch", - batch_size=100, - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - results = [] - - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: - rows = yield self.runInteraction( - desc, - self.simple_select_many_txn, - table, - column, - chunk, - keyvalues, - retcols, - ) - - results.extend(rows) - - return results - - @classmethod - def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - if not iterable: - return [] - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join(clauses), - ) - - txn.execute(sql, values) - return cls.cursor_to_dict(txn) - - def simple_update(self, table, keyvalues, updatevalues, desc): - return self.runInteraction( - desc, self.simple_update_txn, table, keyvalues, updatevalues - ) - - @staticmethod - def simple_update_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) - - return txn.rowcount - - def simple_update_one( - self, table, keyvalues, updatevalues, desc="simple_update_one" - ): - """Executes an UPDATE query on the named table, setting new values for - columns in a row matching the key values. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. - """ - return self.runInteraction( - desc, self.simple_update_one_txn, table, keyvalues, updatevalues - ) - - @classmethod - def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) - - if rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - @staticmethod - def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(select_sql, list(keyvalues.values())) - row = txn.fetchone() - - if not row: - if allow_none: - return None - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - return dict(zip(retcols, row)) - - def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) - - @staticmethod - def simple_delete_one_txn(txn, table, keyvalues): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - if txn.rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - def simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) - - @staticmethod - def simple_delete_txn(txn, table, keyvalues): - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - return txn.rowcount - - def simple_delete_many(self, table, column, iterable, keyvalues, desc): - return self.runInteraction( - desc, self.simple_delete_many_txn, table, column, iterable, keyvalues - ) - - @staticmethod - def simple_delete_many_txn(txn, table, column, iterable, keyvalues): - """Executes a DELETE query on the named table. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - - Returns: - int: Number rows deleted - """ - if not iterable: - return 0 - - sql = "DELETE FROM %s" % table - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - txn.execute(sql, values) - - return txn.rowcount - - def get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - %(limit)s" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - "limit": limit, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - - cache = {row[0]: int(row[1]) for row in txn} - - txn.close() - - if cache: - min_val = min(itervalues(cache)) - else: - min_val = max_value - - return cache, min_val - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1347,159 +71,6 @@ class SQLBaseStore(object): # which is fine. pass - def simple_select_list_paginate( - self, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - desc="simple_select_list_paginate", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self.simple_select_list_paginate_txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction=order_direction, - ) - - @classmethod - def simple_select_list_paginate_txn( - cls, - txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - if order_direction not in ["ASC", "DESC"]: - raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") - - if keyvalues: - where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) - else: - where_clause = "" - - sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( - ", ".join(retcols), - table, - where_clause, - orderby, - order_direction, - ) - txn.execute(sql, list(keyvalues.values()) + [limit, start]) - - return cls.cursor_to_dict(txn) - - def get_user_count_txn(self, txn): - """Get a total number of registered users in the users list. - - Args: - txn : Transaction object - Returns: - int : number of users - """ - sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" - txn.execute(sql_count) - return txn.fetchone()[0] - - def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - - return self.runInteraction( - desc, self.simple_search_list_txn, table, term, col, retcols - ) - - @classmethod - def simple_search_list_txn(cls, txn, table, term, col, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) - termvalues = ["%%" + term + "%%"] - txn.execute(sql, termvalues) - else: - return 0 - - return cls.cursor_to_dict(txn) - - -class _RollbackButIsFineException(Exception): - """ This exception is used to rollback a transaction without implying - something went wrong. - """ - - pass - def db_to_json(db_content): """ @@ -1528,30 +99,3 @@ def db_to_json(db_content): except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise - - -def make_in_list_sql_clause( - database_engine, column: str, iterable: Iterable -) -> Tuple[str, Iterable]: - """Returns an SQL clause that checks the given column is in the iterable. - - On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres - it expands to `column = ANY(?)`. While both DBs support the `IN` form, - using the `ANY` form on postgres means that it views queries with - different length iterables as the same, helping the query stats. - - Args: - database_engine - column: Name of the column - iterable: The values to check the column against. - - Returns: - A tuple of SQL query and the args - """ - - if database_engine.supports_using_any_list: - # This should hopefully be faster, but also makes postgres query - # stats easier to understand. - return "%s = ANY(?)" % (column,), [list(iterable)] - else: - return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 06955a0537..dfca94b0e0 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self.simple_select_onecol( + updates = yield self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore): if update_name in self._background_update_queue: return False - update_exists = await self.simple_select_one_onecol( + update_exists = await self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore): no more work to do. """ if not self._background_update_queue: - updates = yield self.simple_select_list( + updates = yield self.db.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self.simple_select_one_onecol( + progress_json = yield self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -391,7 +391,7 @@ class BackgroundUpdateStore(SQLBaseStore): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.runWithConnection(runner) + yield self.db.runWithConnection(runner) yield self._end_background_update(update_name) return 1 @@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [] progress_json = json.dumps(progress) - return self.simple_insert( + return self.db.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self.simple_delete_one( + return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 2a5b33dda1..46f0f26af6 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -171,9 +171,11 @@ class DataStore( else: self._cache_id_gen = None + super(DataStore, self).__init__(db_conn, hs) + self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self.get_cache_dict( + presence_cache_prefill, min_presence_val = self.db.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -187,7 +189,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -202,7 +204,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self.get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -228,7 +230,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -242,7 +244,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -262,8 +264,6 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() - super(DataStore, self).__init__(db_conn, hs) - def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None @@ -283,7 +283,7 @@ class DataStore( txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) txn.close() for row in rows: @@ -296,7 +296,7 @@ class DataStore( Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) + return self.db.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -306,7 +306,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( + return self.db.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -406,7 +406,7 @@ class DataStore( return results - return self.runInteraction("count_r30_users", _count_r30_users) + return self.db.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -471,7 +471,7 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.runInteraction( + return self.db.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -482,7 +482,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self.simple_select_list( + return self.db.simple_select_list( table="users", keyvalues={}, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], @@ -502,9 +502,9 @@ class DataStore( Returns: defer.Deferred: resolves to json object {list[dict[str, Any]], count} """ - users = yield self.runInteraction( + users = yield self.db.runInteraction( "get_users_paginate", - self.simple_select_list_paginate_txn, + self.db.simple_select_list_paginate_txn, table="users", keyvalues={"is_guest": False}, orderby=order, @@ -512,7 +512,9 @@ class DataStore( limit=limit, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) + count = yield self.db.runInteraction( + "get_users_paginate", self.get_user_count_txn + ) retval = {"users": users, "total": count} return retval @@ -526,7 +528,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self.simple_search_list( + return self.db.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index b0d22faf3f..a96fe9485c 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -92,7 +92,7 @@ class AccountDataWorkerStore(SQLBaseStore): return global_account_data, by_room - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self.simple_select_one_onecol( + result = yield self.db.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -138,7 +138,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self.simple_select_one_onecol_txn( + content_json = self.db.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore): return json.loads(content_json) if content_json else None - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -207,7 +207,7 @@ class AccountDataWorkerStore(SQLBaseStore): room_results = txn.fetchall() return global_results, room_results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) @@ -252,7 +252,7 @@ class AccountDataWorkerStore(SQLBaseStore): if not changed: return {}, {} - return self.runInteraction( + return self.db.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -302,7 +302,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.simple_upsert( + yield self.db.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -348,7 +348,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.simple_upsert( + yield self.db.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -388,4 +388,4 @@ class AccountDataStore(AccountDataWorkerStore): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.runInteraction("update_account_data_max_stream_id", _update) + return self.db.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 6b82fd392a..6b2e12719c 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self.simple_select_list( + results = yield self.db.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self.simple_select_one( + result = yield self.db.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self.simple_upsert( + return self.db.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -216,7 +216,7 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.runInteraction("create_appservice_txn", _create_appservice_txn) + return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. @@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,11 +257,13 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self.simple_delete_txn( + self.db.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) - return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) + return self.db.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): @@ -283,7 +285,7 @@ class ApplicationServiceTransactionWorkerStore( " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None @@ -291,7 +293,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.runInteraction( + entry = yield self.db.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -321,7 +323,7 @@ class ApplicationServiceTransactionWorkerStore( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.runInteraction( + return self.db.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) @@ -350,7 +352,7 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index de3256049d..54ed8574c4 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="cache_invalidation_stream", values={ @@ -122,7 +122,9 @@ class CacheInvalidationStore(SQLBaseStore): txn.execute(sql, (last_id, limit)) return txn.fetchall() - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) def get_cache_stream_token(self): if self._cache_id_gen: diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 66522a04b7..6f2a720b97 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -91,7 +91,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.runWithConnection(f) + yield self.db.runWithConnection(f) yield self._end_background_update("user_ips_drop_nonunique_index") return 1 @@ -106,7 +106,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.runInteraction("user_ips_analyze", user_ips_analyze) + yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) yield self._end_background_update("user_ips_analyze") @@ -140,7 +140,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.runInteraction( + end_last_seen = yield self.db.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -275,7 +275,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.runInteraction("user_ips_dups_remove", remove) + yield self.db.runInteraction("user_ips_dups_remove", remove) if last: yield self._end_background_update("user_ips_remove_dupes") @@ -352,7 +352,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): return len(rows) - updated = yield self.runInteraction( + updated = yield self.db.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) @@ -417,12 +417,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.runInteraction( + return self.db.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self._unsafe_to_upsert_tables or ( + if "user_ips" in self.db._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") @@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self.simple_select_list( + res = yield self.db.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -577,4 +577,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _prune_old_user_ips_txn(txn): txn.execute(sql, (timestamp,)) - await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) + await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 206d39134d..440793ad49 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.runInteraction( + count = yield self.db.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -203,7 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.runInteraction( + return self.db.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) @@ -232,7 +232,7 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) @@ -294,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self.simple_select_one_txn( + already_inserted = self.db.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -344,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -465,6 +465,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) return rows - return self.runInteraction( + return self.db.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 727c582121..d98511ddd4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self.simple_select_one( + return self.db.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self.simple_select_list( + devices = yield self.db.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore): # consider the device update to be too large, and simply skip the # stream_id; the rationale being that such a large device list update # is likely an error. - updates = yield self.runInteraction( + updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore): """ devices = ( - yield self.runInteraction( + yield self.db.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.runInteraction("get_last_device_update_for_remote_user", f) + return self.db.runInteraction("get_last_device_update_for_remote_user", f) def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. """ - return self.runInteraction( + return self.db.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore): """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore): from_user_id, stream_id, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self.simple_select_one_onecol( + content = yield self.db.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self.simple_select_list( + devices = yield self.db.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: (stream_id, devices) """ - return self.runInteraction( + return self.db.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, @@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.runInteraction( + return self.db.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) @@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.execute( + rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self.execute( + return self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -685,7 +685,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.runWithConnection(f) + yield self.db.runWithConnection(f) yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) return 1 @@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.simple_insert( + inserted = yield self.db.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.simple_select_one_onecol( + hidden = yield self.db.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self.simple_delete_one( + yield self.db.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self.simple_delete_many( + yield self.db.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self.simple_update_one( + return self.db.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self.simple_delete( + yield self.db.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -853,7 +853,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -914,7 +914,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -962,7 +962,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): (if any) should be poked. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, user_id, @@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ @@ -1069,7 +1069,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return run_as_background_process( "prune_old_outbound_device_pokes", - self.runInteraction, + self.db.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index d332f8a409..c9e7de7d12 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self.simple_select_one_onecol( + room_id = yield self.db.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self.simple_select_onecol( + servers = yield self.db.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self.simple_select_onecol( + return self.db.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.runInteraction("create_room_alias_association", alias_txn) + ret = yield self.db.runInteraction( + "create_room_alias_association", alias_txn + ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() @@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def delete_room_alias(self, room_alias): - room_id = yield self.runInteraction( + room_id = yield self.db.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) @@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.runInteraction( + return self.db.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index df89eda337..84594cf0a9 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self.simple_update_one( + yield self.db.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self.simple_insert_many( + yield self.db.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -170,7 +170,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version (str): the version ID of the backup we're querying about """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self.simple_delete( + yield self.db.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self.simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, @@ -324,7 +324,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): result["etag"] = 0 return result - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.runInteraction( + return self.db.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self.simple_update( + return self.db.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -420,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self.simple_update_one_txn( + return self.db.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 08bcdc4725..38cd0ca9b8 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) result = {} for row in rows: @@ -143,7 +143,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(signature_sql, signature_query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) # add each cross-signing signature to the correct device in the result dict. for row in rows: @@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): key_id) to json string for key """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -238,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.runInteraction( + yield self.db.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -261,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result[algorithm] = key_count return result - return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + return self.db.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): """Returns a user's cross-signing key. @@ -322,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: dict of the key data or None if not found """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_cross_signing_key", self._get_e2e_cross_signing_key_txn, user_id, @@ -350,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self.execute( + return self.db.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -367,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self.simple_select_one_onecol_txn( + old_key_json = self.db.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -383,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -392,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"message": "Device keys stored."}) return True - return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" @@ -431,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) return result - return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) + return self.db.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device_txn(txn): @@ -442,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -456,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) @@ -492,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "devices", values={ @@ -505,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -524,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.runInteraction( + return self.db.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, user_id, @@ -539,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self.simple_insert_many( + return self.db.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 051ac7a8cb..77e4353b59 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -58,7 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of event_ids """ - return self.runInteraction( + return self.db.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given ) @@ -90,12 +90,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) def get_oldest_events_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.simple_select_onecol_txn( + return self.db.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -188,7 +188,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas where *hashes* is a map from algorithm to hash. """ - return self.runInteraction( + return self.db.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, room_id, @@ -229,13 +229,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -266,12 +266,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ - return self.runInteraction( + return self.db.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self.simple_select_one_onecol_txn( + min_depth = self.db.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -337,7 +337,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) @@ -352,7 +352,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas limit (int) """ return ( - self.runInteraction( + self.db.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, @@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self.simple_select_one_onecol_txn( + depth = self.db.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -415,7 +415,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.runInteraction( + ids = yield self.db.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore): if min_depth and depth >= min_depth: return - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore): For the given event, update the event edges table and forward and backward extremities tables. """ - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_edges", values=[ @@ -604,13 +604,13 @@ class EventFederationStore(EventFederationWorkerStore): return run_as_background_process( "delete_old_forward_extrem_cache", - self.runInteraction, + self.db.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -660,7 +660,7 @@ class EventFederationStore(EventFederationWorkerStore): return min_stream_id >= target_min_stream_id - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 0a37847cfd..725d0881dc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -93,7 +93,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, @@ -177,7 +177,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.runInteraction("get_push_action_users_in_range", f) + ret = yield self.db.runInteraction("get_push_action_users_in_range", f) return ret @defer.inlineCallbacks @@ -229,7 +229,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -257,7 +257,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -329,7 +329,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -357,7 +357,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -407,7 +407,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) @@ -458,7 +458,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.runInteraction( + return self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) @@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self.simple_delete( + res = yield self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -489,7 +489,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.runInteraction, + self.db.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -525,7 +525,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): Deferred[int]: stream ordering of the first event received on/after the timestamp """ - return self.runInteraction( + return self.db.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) for event, _ in events_and_contexts: - user_ids = self.simple_select_onecol_txn( + user_ids = self.db.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -727,9 +727,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - push_actions = yield self.runInteraction("get_push_actions_for_user", f) + push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions @@ -748,7 +748,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.runInteraction("get_time_of_last_push_action_before", f) + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @defer.inlineCallbacks @@ -757,7 +757,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) + result = yield self.db.runInteraction( + "get_latest_push_action_stream_ordering", f + ) return result[0] or 0 def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): @@ -830,7 +832,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.runInteraction( + caught_up = yield self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: @@ -844,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +882,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +914,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 98ae69e996..01ec9ec397 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -143,7 +143,7 @@ class EventsStore( ) return txn.fetchall() - res = yield self.runInteraction("read_forward_extremities", fetch) + res = yield self.db.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) @_retry_on_integrity_error @@ -208,7 +208,7 @@ class EventsStore( for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.runInteraction( + yield self.db.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -281,7 +281,7 @@ class EventsStore( results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) @@ -345,7 +345,7 @@ class EventsStore( existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) @@ -432,7 +432,7 @@ class EventsStore( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -580,12 +580,12 @@ class EventsStore( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -598,7 +598,7 @@ class EventsStore( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -722,7 +722,7 @@ class EventsStore( # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -794,7 +794,7 @@ class EventsStore( d.pop("redacted_because", None) return d - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_json", values=[ @@ -811,7 +811,7 @@ class EventsStore( ], ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="events", values=[ @@ -841,7 +841,7 @@ class EventsStore( # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self.simple_update_txn( + self.db.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -983,7 +983,7 @@ class EventsStore( state_values.append(vals) - self.simple_insert_many_txn(txn, table="state_events", values=state_values) + self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1014,7 +1014,7 @@ class EventsStore( ) txn.execute(sql + clause, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: @@ -1032,7 +1032,7 @@ class EventsStore( # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="redactions", values={ @@ -1077,7 +1077,9 @@ class EventsStore( LIMIT ? """ - rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100) + rows = yield self.db.execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) updates = [] @@ -1109,14 +1111,14 @@ class EventsStore( if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, updatevalues={"have_censored": True}, ) - yield self.runInteraction("_update_censor_txn", _update_censor_txn) + yield self.db.runInteraction("_update_censor_txn", _update_censor_txn) def _censor_event_txn(self, txn, event_id, pruned_json): """Censor an event by replacing its JSON in the event_json table with the @@ -1127,7 +1129,7 @@ class EventsStore( event_id (str): The ID of the event to censor. pruned_json (str): The pruned JSON """ - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="event_json", keyvalues={"event_id": event_id}, @@ -1153,7 +1155,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_messages", _count_messages) + ret = yield self.db.runInteraction("count_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1174,7 +1176,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1189,7 +1191,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_active_rooms", _count) + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret def get_current_backfill_token(self): @@ -1241,7 +1243,7 @@ class EventsStore( return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1286,7 +1288,7 @@ class EventsStore( return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @@ -1379,7 +1381,7 @@ class EventsStore( backward_ex_outliers, ) - return self.runInteraction("get_all_new_events", get_all_new_events_txn) + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) def purge_history(self, room_id, token, delete_local_events): """Deletes room history before a certain point @@ -1399,7 +1401,7 @@ class EventsStore( deleted events. """ - return self.runInteraction( + return self.db.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -1647,7 +1649,7 @@ class EventsStore( Deferred[List[int]]: The list of state groups to delete. """ - return self.runInteraction("purge_room", self._purge_room_txn, room_id) + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before @@ -1766,7 +1768,7 @@ class EventsStore( to delete. """ - return self.runInteraction( + return self.db.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -1778,7 +1780,7 @@ class EventsStore( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self.simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1805,15 +1807,15 @@ class EventsStore( curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1850,7 +1852,7 @@ class EventsStore( state group. """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1869,7 +1871,7 @@ class EventsStore( state_groups_to_delete (list[int]): State groups to delete """ - return self.runInteraction( + return self.db.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, @@ -1880,7 +1882,7 @@ class EventsStore( # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1891,7 +1893,7 @@ class EventsStore( # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1902,7 +1904,7 @@ class EventsStore( # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1919,7 +1921,7 @@ class EventsStore( @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self.simple_select_one( + res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1942,7 +1944,7 @@ class EventsStore( txn.execute(sql, (from_token, to_token, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1960,7 +1962,7 @@ class EventsStore( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self.simple_insert_many_txn( + return self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1982,7 +1984,7 @@ class EventsStore( event_id (str): The event ID the expiry timestamp is associated with. expiry_ts (int): The timestamp at which to expire (delete) the event. """ - return self.simple_insert_txn( + return self.db.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, @@ -2031,7 +2033,7 @@ class EventsStore( txn, "_get_event_cache", (event.event_id,) ) - yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) def _delete_event_expiry_txn(self, txn, event_id): """Delete the expiry timestamp associated with an event ID without deleting the @@ -2041,7 +2043,7 @@ class EventsStore( txn (LoggingTransaction): The transaction to use to perform the deletion. event_id (str): The event ID to delete the associated expiry timestamp of. """ - return self.simple_delete_txn( + return self.db.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) @@ -2065,7 +2067,7 @@ class EventsStore( return txn.fetchone() - return self.runInteraction( + return self.db.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 37dfc8c871..365e966956 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -151,7 +151,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) @@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self.simple_select_many_txn( + ev_rows = self.db.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -228,7 +228,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(rows_to_update) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) @@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self.simple_delete_many_txn( + deleted = self.db.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self.simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -406,7 +406,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(original_set) - num_handled = yield self.runInteraction( + num_handled = yield self.db.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) @@ -416,7 +416,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.runInteraction( + yield self.db.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) @@ -470,7 +470,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(rows) - count = yield self.runInteraction( + count = yield self.db.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) @@ -501,7 +501,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.runInteraction( + yield self.db.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) @@ -533,7 +533,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): try: event_json = json.loads(event_json_raw) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -565,7 +565,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return nbrows - num_rows = yield self.runInteraction( + num_rows = yield self.db.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 6a08a746b6..e041fc5eac 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -78,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -117,7 +117,7 @@ class EventsWorkerStore(SQLBaseStore): return ts - return self.runInteraction( + return self.db.runInteraction( "get_approximate_received_ts", _get_approximate_received_ts_txn ) @@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events ) - row_dict = self.new_transaction( + row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -584,7 +584,7 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch + "fetch_events", self.db.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) @@ -745,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -780,7 +780,9 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) + yield self.db.runInteraction( + "have_seen_events", have_seen_events_txn, chunk + ) return results def _get_total_state_event_counts_txn(self, txn, room_id): @@ -807,7 +809,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id, @@ -832,7 +834,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index 17ef7b9354..342d6622a4 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.simple_select_one_onecol( + def_json = yield self.db.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.runInteraction("add_user_filter", _do_txn) + return self.db.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 9e1d12bcb7..7f5e8dce66 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore): * "invite" * "open" """ - return self.simple_update_one( + return self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore): ) def get_group(self, group_id): - return self.simple_select_one( + return self.db.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self.simple_select_list( + return self.db.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self.simple_select_list( + return self.db.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore): return rooms, categories - return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) + return self.db.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self.simple_select_one_onecol_txn( + room_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self.simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self.simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self.simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.simple_delete( + return self.db.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self.simple_select_one( + category = yield self.db.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self.simple_upsert( + return self.db.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self.simple_delete( + return self.db.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self.simple_select_one( + role = yield self.db.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self.simple_upsert( + return self.db.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self.simple_delete( + return self.db.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self.simple_select_one_onecol_txn( + user_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self.simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self.simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self.simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.simple_delete( + return self.db.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore): return users, roles - return self.runInteraction( + return self.db.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) def is_user_in_group(self, user_id, group_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self.simple_insert( + return self.db.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore): """ def _get_users_membership_in_group_txn(txn): - row = self.simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self.simple_select_one_onecol_txn( + row = self.db.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore): return {} - return self.runInteraction( + return self.db.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore): }, ) - return self.runInteraction("add_user_to_group", _add_user_to_group_txn) + return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) def add_room_to_group(self, group_id, room_id, is_public): - return self.simple_insert( + return self.db.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.simple_update( + return self.db.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self.simple_update_one( + return self.db.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, @@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self.simple_insert( + yield self.db.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self.simple_update_one( + yield self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore): WHERE valid_until_ms <= ? """ txn.execute(sql, (valid_until_ms,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self.simple_update_one( + return self.db.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self.simple_update_one( + return self.db.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self.simple_delete( + return self.db.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.simple_select_one( + row = yield self.db.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore): return None def get_joined_groups(self, user_id): - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore): for row in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore): for stream_id, group_id, user_id, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) @@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction("delete_group", _delete_group_txn) + return self.db.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index c7150432b3..6b12f5a75f 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.runInteraction("get_server_verify_keys", _txn) + return self.db.runInteraction("get_server_verify_keys", _txn) def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): """Stores NACL verification keys for remote servers. @@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore): f((i,)) return res - return self.runInteraction( + return self.db.runInteraction( "store_server_verify_keys", - self.simple_upsert_many_txn, + self.db.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore): ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self.simple_upsert( + return self.db.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, @@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) + return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 0cb9446f96..ea02497784 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): Returns: None if the media_id doesn't exist. """ - return self.simple_select_one( + return self.db.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id, url_cache=None, ): - return self.simple_insert( + return self.db.simple_insert( "local_media_repository", { "media_id": media_id, @@ -124,12 +124,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) ) - return self.runInteraction("get_url_cache", get_url_cache_txn) + return self.db.runInteraction("get_url_cache", get_url_cache_txn) def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.simple_insert( + return self.db.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self.simple_select_list( + return self.db.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.simple_insert( + return self.db.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self.simple_select_one( + return self.db.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self.simple_insert( + return self.db.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -250,10 +250,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.runInteraction("update_cached_last_access_time", update_cache_txn) + return self.db.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) def get_remote_media_thumbnails(self, origin, media_id): - return self.simple_select_list( + return self.db.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +280,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.simple_insert( + return self.db.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,24 +302,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self.execute( - "get_remote_media_before", self.cursor_to_dict, sql, before_ts + return self.db.execute( + "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.runInteraction("delete_remote_media", delete_remote_media_txn) + return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) def get_expired_url_cache(self, now_ts): sql = ( @@ -331,7 +333,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + return self.db.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) def delete_url_cache(self, media_ids): if len(media_ids) == 0: @@ -342,7 +346,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( @@ -356,7 +360,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.runInteraction( + return self.db.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) @@ -373,6 +377,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction( + return self.db.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b8fc28f97b..34bf3a1880 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self.new_transaction( + self.db.new_transaction( dbconn, "initialise_mau_threepids", [], @@ -146,7 +146,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn.execute(sql, query_args) reserved_users = yield self.get_registered_reserved_users() - yield self.runInteraction( + yield self.db.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) # It seems poor to invalidate the whole cache, Postgres supports @@ -174,7 +174,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): (count,) = txn.fetchone() return count - return self.runInteraction("count_users", _count_users) + return self.db.runInteraction("count_users", _count_users) @defer.inlineCallbacks def get_registered_reserved_users(self): @@ -217,7 +217,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): if is_support: return - yield self.runInteraction( + yield self.db.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.simple_upsert_txn( + is_insert = self.db.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 650e49750e..cc21437e92 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.simple_insert( + return self.db.simple_insert( table="open_id_tokens", values={ "token": token, @@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) + return self.db.runInteraction( + "get_user_id_for_token", get_user_id_for_token_txn + ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index a5e121efd1..a2c83e0867 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore): ) with stream_ordering_manager as stream_orderings: - yield self.runInteraction( + yield self.db.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, @@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) @@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self.simple_insert( + return self.db.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.simple_delete_one( + return self.db.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index c8b5b60301..2b52cf9c1a 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self.simple_select_one( + profile = yield self.db.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self.simple_select_one( + return self.db.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self.simple_insert( + return self.db.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self.simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.simple_upsert( + return self.db.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.simple_update( + return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.simple_delete( + yield self.db.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore): txn.execute(sql, (last_checked,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) @@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 75bd499bcd..de682cc63a 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -75,7 +75,7 @@ class PushRulesWorkerStore( def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) - push_rules_prefill, push_rules_id = self.get_cache_dict( + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +100,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +124,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self.simple_select_list( + results = yield self.db.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -146,7 +146,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.runInteraction( + return self.db.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @@ -162,7 +162,7 @@ class PushRulesWorkerStore( results = {user_id: [] for user_id in user_ids} - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +320,7 @@ class PushRulesWorkerStore( results = {user_id: {} for user_id in user_ids} - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -350,7 +350,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -364,7 +364,7 @@ class PushRuleStore(PushRulesWorkerStore): after, ) else: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self.simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self.simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, @@ -582,7 +582,7 @@ class PushRuleStore(PushRulesWorkerStore): def set_push_rule_enabled(self, user_id, rule_id, enabled): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -655,7 +655,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self.simple_insert_txn(txn, "push_rules_stream", values=values) + self.db.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) @@ -699,7 +699,7 @@ class PushRuleStore(PushRulesWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d5a169872b..f07309ef09 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self.simple_select_one_onecol( + ret = yield self.db.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self.simple_select_list( + ret = yield self.db.simple_select_list( "pushers", keyvalues, [ @@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore): def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.runInteraction("get_all_pushers", get_pushers) + rows = yield self.db.runInteraction("get_all_pushers", get_pushers) return rows def get_all_updated_pushers(self, last_id, current_id, limit): @@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore): return updated, deleted - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn ) @@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore): return results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -230,7 +230,7 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.simple_upsert( + yield self.db.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.runInteraction( + yield self.db.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, @@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self.simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) @defer.inlineCallbacks def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self.simple_update_one( + yield self.db.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore): Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self.simple_update( + updated = yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.simple_update( + yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self.simple_select_list( + res = yield self.db.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -362,7 +362,7 @@ class PusherStore(PusherWorkerStore): def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry - yield self.simple_upsert( + yield self.db.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 380f388e30..ac2d45bd5c 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self.simple_select_list( + return self.db.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -108,7 +108,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], @@ -187,11 +187,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, to_key)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return rows - rows = yield self.runInteraction("get_linearized_receipts_for_room", f) + rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -237,9 +237,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key] + list(args)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) + txn_results = yield self.db.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) results = {} for row in txn_results: @@ -282,7 +284,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return list(r[0:5] + (json.loads(r[5]),) for r in txn) - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) @@ -335,7 +337,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self.simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +390,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +400,7 @@ class ReceiptsStore(ReceiptsWorkerStore): }, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -453,13 +455,13 @@ class ReceiptsStore(ReceiptsWorkerStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.runInteraction( + linearized_event_id = yield self.db.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.runInteraction( + event_ts = yield self.db.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -488,7 +490,7 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.runInteraction( + return self.db.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -514,7 +516,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +525,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index debc6706f5..8f9aa87ceb 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -45,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore): @cached() def get_user_by_id(self, user_id): - return self.simple_select_one( + return self.db.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -94,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore): including the keys `name`, `is_guest`, `device_id`, `token_id`, `valid_until_ms`. """ - return self.runInteraction( + return self.db.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self.simple_update_txn( + self.db.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -151,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.runInteraction( + yield self.db.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) @@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self.simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -229,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - res = yield self.runInteraction( + res = yield self.db.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), @@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self.simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self.simple_delete_one( + yield self.db.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.simple_update_one( + return self.db.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -316,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0] @@ -332,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user 'user_type' is null or empty string """ - res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + res = yield self.db.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) return res @cachedInlineCallbacks() @@ -345,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user is of type UserTypes.SUPPORT """ - res = yield self.runInteraction( + res = yield self.db.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) return res def is_real_user_txn(self, txn, user_id): - res = self.simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self.simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -380,7 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return dict(txn) - return self.runInteraction("get_users_by_id_case_insensitive", f) + return self.db.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -394,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self.simple_select_one_onecol( + return await self.db.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -408,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret def count_daily_user_type(self): @@ -445,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.runInteraction("count_daily_user_type", _count_daily_user_type) + return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): @@ -459,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret @defer.inlineCallbacks @@ -468,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_real_users", _count_users) + ret = yield self.db.runInteraction("count_real_users", _count_users) return ret @defer.inlineCallbacks @@ -503,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ( ( - yield self.runInteraction( + yield self.db.runInteraction( "find_next_generated_user_id", _find_next_generated_user_id ) ) @@ -520,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[str|None]: user id or None if no user id/threepid mapping exists """ - user_id = yield self.runInteraction( + user_id = yield self.db.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -536,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self.simple_select_one_txn( + ret = self.db.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -549,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.simple_upsert( + yield self.db.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -557,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self.simple_select_list( + ret = yield self.db.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -566,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self.simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -579,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id: The user id to delete all threepids of """ - return self.simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -601,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self.simple_upsert( + return self.db.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -627,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.simple_select_list( + return self.db.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -648,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self.simple_delete( + return self.db.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -671,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -689,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -756,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.runInteraction( + return self.db.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) @@ -776,18 +778,18 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) @@ -857,7 +859,7 @@ class RegistrationBackgroundUpdateStore( (last_user, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True, 0 @@ -880,7 +882,7 @@ class RegistrationBackgroundUpdateStore( else: return False, len(rows) - end, nb_processed = yield self.runInteraction( + end, nb_processed = yield self.db.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) @@ -911,7 +913,7 @@ class RegistrationBackgroundUpdateStore( txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.runInteraction( + yield self.db.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) @@ -961,7 +963,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ next_id = self._access_tokens_id_gen.get_next() - yield self.simple_insert( + yield self.db.simple_insert( "access_tokens", { "id": next_id, @@ -1003,7 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Raises: StoreError if the user_id could not be registered. """ - return self.runInteraction( + return self.db.runInteraction( "register_user", self._register_user, user_id, @@ -1037,7 +1039,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self.simple_select_one_txn( + self.db.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1045,7 +1047,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1059,7 +1061,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) else: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "users", values={ @@ -1114,7 +1116,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.simple_insert( + return self.db.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1132,12 +1134,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) + return self.db.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -1152,7 +1156,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1160,7 +1164,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_version", f) + return self.db.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): """Updates the user table to record that we have sent the user a server @@ -1176,7 +1180,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1184,7 +1188,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_server_notice_sent", f) + return self.db.runInteraction("user_set_consent_server_notice_sent", f) def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ @@ -1230,11 +1234,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return tokens_and_devices - return self.runInteraction("user_delete_access_tokens", f) + return self.db.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): - self.simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1242,11 +1246,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, self.get_user_by_access_token, (access_token,) ) - return self.runInteraction("delete_access_token", f) + return self.db.runInteraction("delete_access_token", f) @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1261,7 +1265,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.simple_insert( + return self.db.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1274,7 +1278,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.simple_delete( + return self.db.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1285,7 +1289,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1315,7 +1319,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self.simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1333,7 +1337,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): 400, "This client_secret does not match the provided session_id" ) - row = self.simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1358,7 +1362,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self.simple_update_txn( + self.db.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1368,7 +1372,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return next_link # Return next_link if it exists - return self.runInteraction( + return self.db.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) @@ -1401,7 +1405,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): if validated_at: insertion_values["validated_at"] = validated_at - return self.simple_upsert( + return self.db.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1439,7 +1443,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1452,7 +1456,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1463,7 +1467,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) - return self.runInteraction( + return self.db.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) @@ -1478,7 +1482,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ return txn.execute(sql, (ts,)) - return self.runInteraction( + return self.db.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), @@ -1493,7 +1497,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): deactivated (bool): The value to set for `deactivated`. """ - yield self.runInteraction( + yield self.db.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1501,7 +1505,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1529,14 +1533,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) txn.execute(sql, []) - res = self.cursor_to_dict(txn) + res = self.db.cursor_to_dict(txn) if res: for user in res: self.set_expiration_date_for_user_txn( txn, user["name"], use_delta=True ) - yield self.runInteraction( + yield self.db.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) @@ -1560,7 +1564,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expiration_ts, ) - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, "account_validity", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index f81f9279a1..1c07c7a425 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore): ) def get_rejection_reason(self, event_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index aa5e10538b..046c2b4845 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.runInteraction( + edit_id = yield self.db.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) @@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) @@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore): aggregation_key = relation.get("key") - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore): redacted_event_id (str): The event that was redacted. """ - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index f309e3640c..a26ed47afc 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -54,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore): Returns: A dict containing the room information, or None if the room is unknown. """ - return self.simple_select_one( + return self.db.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -63,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore): ) def get_public_room_ids(self): - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -120,7 +120,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.runInteraction("count_public_rooms", _count_public_rooms_txn) + return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) @defer.inlineCallbacks def get_largest_public_rooms( @@ -253,21 +253,21 @@ class RoomWorkerStore(SQLBaseStore): def _get_largest_public_rooms_txn(txn): txn.execute(sql, query_args) - results = self.cursor_to_dict(txn) + results = self.db.cursor_to_dict(txn) if not forwards: results.reverse() return results - ret_val = yield self.runInteraction( + ret_val = yield self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) defer.returnValue(ret_val) @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -288,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.simple_select_one( + row = yield self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -330,9 +330,9 @@ class RoomWorkerStore(SQLBaseStore): (room_id,), ) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) @@ -396,7 +396,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): (last_room, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True @@ -408,7 +408,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): ev = json.loads(row["json"]) retention_policy = json.dumps(ev["content"]) - self.simple_insert_txn( + self.db.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -430,7 +430,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): else: return False - end = yield self.runInteraction( + end = yield self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) @@ -461,7 +461,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "rooms", { @@ -471,7 +471,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) if is_public: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -482,7 +482,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction("store_room_txn", store_room_txn, next_id) + yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -490,14 +490,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self.simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -515,7 +515,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -528,7 +528,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @@ -555,7 +555,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -568,7 +568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -578,7 +578,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - entries = self.simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -596,7 +596,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -609,7 +609,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -626,7 +626,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.runInteraction("get_rooms", f) + return self.db.runInteraction("get_rooms", f) def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: @@ -660,7 +660,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # Ignore the event if one of the value isn't an integer. return - self.simple_insert_txn( + self.db.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -679,7 +679,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self.simple_insert( + return self.db.simple_insert( table="event_reports", values={ "id": next_id, @@ -712,7 +712,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): if prev_id == current_id: return defer.succeed([]) - return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) @defer.inlineCallbacks def block_room(self, room_id, user_id): @@ -725,14 +727,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Returns: Deferred """ - yield self.simple_upsert( + yield self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.runInteraction( + yield self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, @@ -763,7 +765,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return local_media_mxcs, remote_media_mxcs - return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + return self.db.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines @@ -802,7 +806,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return total_media_quarantined - return self.runInteraction( + return self.db.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -907,7 +911,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): txn.execute(sql, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) rooms_dict = {} for row in rows: @@ -923,7 +927,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): txn.execute(sql) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) # If a room isn't already in the dict (i.e. it doesn't have a retention # policy in its state), add it with a null policy. @@ -936,7 +940,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index fe2428a281..7f4d02b25b 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -116,7 +116,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.runInteraction("get_known_servers", _transact) + count = yield self.db.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self.simple_select_one_txn( + pending_update = self.db.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -144,7 +144,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): 15.0, run_as_background_process, "_check_safe_current_state_events_membership_updated", - self.runInteraction, + self.db.runInteraction, "_check_safe_current_state_events_membership_updated", self._check_safe_current_state_events_membership_updated_txn, ) @@ -161,7 +161,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -269,7 +269,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.runInteraction("get_room_summary", _get_room_summary_txn) + return self.db.runInteraction("get_room_summary", _get_room_summary_txn) def _get_user_counts_in_room_txn(self, txn, room_id): """ @@ -339,7 +339,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): if not membership_list: return defer.succeed(None) - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, user_id, @@ -392,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] if do_invite: sql = ( @@ -412,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): stream_ordering=r["stream_ordering"], membership=Membership.INVITE, ) - for r in self.cursor_to_dict(txn) + for r in self.db.cursor_to_dict(txn) ) return results @@ -603,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -643,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -683,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -753,7 +753,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = txn.fetchall() return rows[0][0] - count = yield self.runInteraction("did_forget_membership", f) + count = yield self.db.runInteraction("did_forget_membership", f) return count == 0 @cached() @@ -790,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return set(row[0] for row in txn if row[1] == 0) - return self.runInteraction( + return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -805,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self.simple_select_onecol( + room_ids = yield self.db.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -820,7 +820,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Get user_id and membership of a set of event IDs. """ - return self.simple_select_many_batch( + return self.db.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -874,7 +874,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -915,7 +915,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) @@ -971,7 +971,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.runInteraction( + row_count, finished = yield self.db.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, @@ -990,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1028,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_invites", values={ @@ -1068,7 +1068,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn.execute(sql, (stream_ordering, True, room_id, user_id)) with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) + yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" @@ -1091,7 +1091,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.runInteraction("forget_membership", f) + return self.db.runInteraction("forget_membership", f) class _JoinedHostsCache(object): diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index f735cf095c..55a604850e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -93,7 +93,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -159,7 +159,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): return len(event_search_rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) @@ -206,7 +206,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) return 1 @@ -237,12 +237,12 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) conn.set_session(autocommit=False) - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.runInteraction( + yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, @@ -280,7 +280,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): return len(rows), True - num_rows, finished = yield self.runInteraction( + num_rows, finished = yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) @@ -441,7 +441,9 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_msgs", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -455,8 +457,8 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self.execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -586,7 +588,9 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_rooms", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -600,8 +604,8 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self.execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -686,7 +690,7 @@ class SearchStore(SearchBackgroundUpdateStore): return highlight_words - return self.runInteraction("_find_highlights", f) + return self.db.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index f3da29ce14..563216b63c 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore): for event_id in event_ids } - return self.runInteraction("get_event_reference_hashes", f) + return self.db.runInteraction("get_event_reference_hashes", f) @defer.inlineCallbacks def add_event_hashes(self, event_ids): @@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore): } ) - self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 2b33ec1a35..851e81d6b3 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self.simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self.simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -348,7 +348,9 @@ class StateGroupWorkerStore( (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn } - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) + return self.db.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn + ) # FIXME: how should this be cached? def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): @@ -392,7 +394,7 @@ class StateGroupWorkerStore( return results - return self.runInteraction( + return self.db.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @@ -431,7 +433,7 @@ class StateGroupWorkerStore( """ def _get_state_group_delta_txn(txn): - prev_group = self.simple_select_one_onecol_txn( + prev_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +444,7 @@ class StateGroupWorkerStore( if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self.simple_select_list_txn( + delta_ids = self.db.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -454,7 +456,9 @@ class StateGroupWorkerStore( {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -540,7 +544,7 @@ class StateGroupWorkerStore( chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -644,7 +648,7 @@ class StateGroupWorkerStore( @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +665,7 @@ class StateGroupWorkerStore( def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +906,7 @@ class StateGroupWorkerStore( state_group = self.database_engine.get_next_state_group_id(txn) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +915,7 @@ class StateGroupWorkerStore( # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self.simple_select_one_onecol_txn( + is_in_db = self.db.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +930,13 @@ class StateGroupWorkerStore( potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +951,7 @@ class StateGroupWorkerStore( ], ) else: - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -993,7 +997,7 @@ class StateGroupWorkerStore( return state_group - return self.runInteraction("store_state_group", _store_state_group_txn) + return self.db.runInteraction("store_state_group", _store_state_group_txn) @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): @@ -1007,7 +1011,7 @@ class StateGroupWorkerStore( referenced. """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1065,7 +1069,7 @@ class StateBackgroundUpdateStore( batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self.execute( + rows = yield self.db.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1139,13 @@ class StateBackgroundUpdateStore( if prev_state.get(key, None) != value } - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1154,13 @@ class StateBackgroundUpdateStore( }, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1183,7 +1187,7 @@ class StateBackgroundUpdateStore( return False, batch_size - finished, result = yield self.runInteraction( + finished, result = yield self.db.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) @@ -1218,7 +1222,7 @@ class StateBackgroundUpdateStore( ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) @@ -1263,7 +1267,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): state_groups[event.event_id] = context.state_group - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 03b908026b..12c982cb26 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore): ORDER BY stream_id ASC """ txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.cursor_to_dict(txn) + return clipped_stream_id, self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self.simple_select_one_onecol_txn( + return self.db.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, @@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore): ) def get_max_stream_id_in_current_state_deltas(self): - return self.runInteraction( + return self.db.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 3aeba859fd..974ffc15bd 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -117,7 +117,7 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) @@ -130,7 +130,7 @@ class StatsStore(StateDeltasStore): yield self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.runInteraction( + yield self.db.runInteraction( "populate_stats_process_users", self._background_update_progress_txn, "populate_stats_process_users", @@ -160,7 +160,7 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) @@ -173,7 +173,7 @@ class StatsStore(StateDeltasStore): yield self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.runInteraction( + yield self.db.runInteraction( "_populate_stats_process_rooms", self._background_update_progress_txn, "populate_stats_process_rooms", @@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self.simple_upsert( + return self.db.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -236,7 +236,7 @@ class StatsStore(StateDeltasStore): Deferred[list[dict]], where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.runInteraction( + return self.db.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self.simple_select_list_paginate_txn( + slice_list = self.db.simple_select_list_paginate_txn( txn, table + "_historical", {id_col: stats_id}, @@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self.simple_select_one( + return self.db.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,14 +344,14 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.runInteraction( + return self.db.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) @@ -382,7 +382,7 @@ class StatsStore(StateDeltasStore): Does not work with per-slice fields. """ - return self.runInteraction( + return self.db.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self.simple_select_one_txn( + current_row = self.db.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self.simple_insert_txn(txn, table, merged_dict) + self.db.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self.simple_update_one_txn(txn, table, keyvalues, current_row) + self.db.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self.simple_select_one_txn( + src_row = self.db.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.simple_select_one_txn( + dest_current_row = self.db.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self.simple_insert_txn(txn, into_table, merged_dict) + self.db.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -652,7 +652,7 @@ class StatsStore(StateDeltasStore): changes. """ - return self.runInteraction( + return self.db.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self.simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -791,7 +791,7 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.runInteraction( + ) = yield self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) @@ -866,7 +866,7 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.runInteraction( + joined_rooms, pos = yield self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 60487c4559..2ff8c57109 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self.get_cache_dict( + event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -400,7 +400,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.runInteraction("get_room_events_stream_for_room", f) + rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -450,7 +450,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.runInteraction("get_membership_changes_for_user", f) + rows = yield self.db.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -548,7 +548,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.runInteraction("get_room_event_after_stream_ordering", _f) + return self.db.runInteraction("get_room_event_after_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): @@ -562,7 +562,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if room_id is None: return "s%d" % (token,) else: - topo = yield self.runInteraction( + topo = yield self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) @@ -576,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -589,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self.simple_select_one( + return self.db.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -613,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self.execute( + return self.db.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -667,7 +667,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -709,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self.simple_select_one_txn( + results = self.db.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -788,7 +788,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) @@ -797,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -805,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return self.simple_update_one( + return self.db.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, @@ -956,7 +956,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 85012403be..2aa1bafd48 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self.simple_select_list( + deferred = self.db.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -78,7 +78,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.runInteraction( + tag_ids = yield self.db.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -98,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore): batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.runInteraction( + tags = yield self.db.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], @@ -135,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) + room_ids = yield self.db.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) results = {} if room_ids: @@ -153,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self.simple_select_list( + return self.db.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -178,7 +180,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -187,7 +189,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("add_tag", add_tag_txn, next_id) + yield self.db.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -210,7 +212,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index c162f3ea16..c0d155a43c 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -77,7 +77,7 @@ class TransactionStore(SQLBaseStore): this transaction or a 2-tuple of (int, dict) """ - return self.runInteraction( + return self.db.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self.simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self.simple_insert( + return self.db.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -148,7 +148,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.runInteraction( + result = yield self.db.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self.simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -187,7 +187,7 @@ class TransactionStore(SQLBaseStore): """ self._destination_retry_cache.pop(destination, None) - return self.runInteraction( + return self.db.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self.simple_select_one_txn( + prev_row = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -270,4 +270,6 @@ class TransactionStore(SQLBaseStore): def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) + return self.db.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 1a85aabbfb..7118bd62f3 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) yield self._end_background_update("populate_user_directory_createtables") return 1 @@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ Update the user directory stream position, then clean up the old tables. """ - position = yield self.simple_select_one_onecol( + position = yield self.db.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -126,7 +126,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) @@ -170,7 +170,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore return rooms_to_work_on - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) @@ -243,10 +243,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore to_insert.clear() # We've finished a room. Delete it from the table. - yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", self._background_update_progress_txn, "populate_user_directory_process_rooms", @@ -291,7 +291,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore return users_to_work_on - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) @@ -312,10 +312,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) # We've finished processing a user. Delete it from the table. - yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", self._background_update_progress_txn, "populate_user_directory_process_users", @@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _update_profile_in_user_dir_txn(txn): - new_entry = self.simple_upsert_txn( + new_entry = self.db.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -448,7 +448,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction( + return self.db.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) @@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _add_users_who_share_room_txn(txn): - self.simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -474,7 +474,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore def _add_users_in_public_rooms_txn(txn): - self.simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -498,7 +498,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) @@ -513,13 +513,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.runInteraction( + return self.db.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @cached() def get_user_in_directory(self, user_id): - return self.simple_select_one( + return self.db.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) def update_user_directory_stream_pos(self, stream_id): - return self.simple_update_one( + return self.db.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -547,42 +547,42 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) @defer.inlineCallbacks def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self.simple_select_onecol( + user_ids_share_pub = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self.simple_select_onecol( + user_ids_share_priv = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,23 +605,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) @@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self.simple_select_onecol( + rows = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self.simple_select_onecol( + pub_rows = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self.execute( + rows = yield self.db.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,7 +786,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_user_dir", self.db.cursor_to_dict, sql, *args + ) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index 37860af070..af8025bc17 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.runInteraction("mark_user_erased", f) + return self.db.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/database.py b/synapse/storage/database.py new file mode 100644 index 0000000000..c2e121a001 --- /dev/null +++ b/synapse/storage/database.py @@ -0,0 +1,1485 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +import sys +import time +from typing import Iterable, Tuple + +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range + +from prometheus_client import Histogram + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.util.stringutils import exception_to_unicode + +# import a function which will return a monotonic time, in seconds +try: + # on python 3, use time.monotonic, since time.clock can go backwards + from time import monotonic as monotonic_time +except ImportError: + # ... but python 2 doesn't have it + from time import clock as monotonic_time + +logger = logging.getLogger(__name__) + +try: + MAX_TXN_ID = sys.maxint - 1 +except AttributeError: + # python 3 does not have a maximum int value + MAX_TXN_ID = 2 ** 63 - 1 + +sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") +perf_logger = logging.getLogger("synapse.storage.TIME") + +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") + +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) + + +# Unique indexes which have been added in background updates. Maps from table name +# to the name of the background update which added the unique index to that table. +# +# This is used by the upsert logic to figure out which tables are safe to do a proper +# UPSERT on: until the relevant background update has completed, we +# have to emulate an upsert by locking the table. +# +UNIQUE_INDEX_BACKGROUND_UPDATES = { + "user_ips": "user_ips_device_unique_index", + "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", + "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", + "event_search": "event_search_event_id_idx", +} + + +class LoggingTransaction(object): + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging and metrics to the .execute() + method. + + Args: + txn: The database transcation object to wrap. + name (str): The name of this transactions for logging. + database_engine (Sqlite3Engine|PostgresEngine) + after_callbacks(list|None): A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks(list|None): A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ + + __slots__ = [ + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", + ] + + def __init__( + self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + ): + object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) + object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) + + def call_after(self, callback, *args, **kwargs): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) + + def __getattr__(self, name): + return getattr(self.txn, name) + + def __setattr__(self, name, value): + setattr(self.txn, name, value) + + def __iter__(self): + return self.txn.__iter__() + + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch + + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + + def execute(self, sql, *args): + self._do_execute(self.txn.execute, sql, *args) + + def executemany(self, sql, *args): + self._do_execute(self.txn.executemany, sql, *args) + + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + + # TODO(paul): Maybe use 'info' and 'debug' for values? + sql_logger.debug("[SQL] {%s} %s", self.name, sql) + + sql = self.database_engine.convert_param_style(sql) + if args: + try: + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + except Exception: + # Don't let logging failures stop SQL from working + pass + + start = time.time() + + try: + return func(sql, *args) + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise + finally: + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) + + +class PerformanceCounters(object): + def __init__(self): + self.current_counters = {} + self.previous_counters = {} + + def update(self, key, duration_secs): + count, cum_time = self.current_counters.get(key, (0, 0)) + count += 1 + cum_time += duration_secs + self.current_counters[key] = (count, cum_time) + + def interval(self, interval_duration_secs, limit=3): + counters = [] + for name, (count, cum_time) in iteritems(self.current_counters): + prev_count, prev_time = self.previous_counters.get(name, (0, 0)) + counters.append( + ( + (cum_time - prev_time) / interval_duration_secs, + count - prev_count, + name, + ) + ) + + self.previous_counters = dict(self.current_counters) + + counters.sort(reverse=True) + + top_n_counters = ", ".join( + "%s(%d): %.3f%%" % (name, count, 100 * ratio) + for ratio, count, name in counters[:limit] + ) + + return top_n_counters + + +class Database(object): + _TXN_ID = 0 + + def __init__(self, hs): + self.hs = hs + self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() + + self._previous_txn_total_time = 0 + self._current_txn_total_time = 0 + self._previous_loop_ts = 0 + + # TODO(paul): These can eventually be removed once the metrics code + # is running in mainline, and we have some nice monitoring frontends + # to watch it + self._txn_perf_counters = PerformanceCounters() + + self.database_engine = hs.database_engine + + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.database_engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + self.rand = random.SystemRandom() + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait 15 sec and check again. + """ + updates = yield self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): + if update_name not in updates: + logger.debug("Now safe to upsert in %s", table) + self._unsafe_to_upsert_tables.discard(table) + + # If there's any updates still running, reschedule to run. + if updates: + self._clock.call_later( + 15.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def start_profiling(self): + self._previous_loop_ts = monotonic_time() + + def loop(): + curr = self._current_txn_total_time + prev = self._previous_txn_total_time + self._previous_txn_total_time = curr + + time_now = monotonic_time() + time_then = self._previous_loop_ts + self._previous_loop_ts = time_now + + duration = time_now - time_then + ratio = (curr - prev) / duration + + top_three_counters = self._txn_perf_counters.interval(duration, limit=3) + + perf_logger.info( + "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters + ) + + self._clock.looping_call(loop, 10000) + + def new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): + start = monotonic_time() + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) + + name = "%s-%x" % (desc, txn_id) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.database_engine, + after_callbacks, + exception_callbacks, + ) + try: + r = func(cursor, *args, **kwargs) + conn.commit() + return r + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warning( + "[TXN OPERROR] {%s} %s %d/%d", + name, + exception_to_unicode(e), + i, + N, + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) + ) + continue + raise + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", + name, + exception_to_unicode(e1), + ) + continue + raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = monotonic_time() + duration = end - start + + LoggingContext.current_context().add_database_transaction(duration) + + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, duration) + sql_txn_timer.labels(desc).observe(duration) + + @defer.inlineCallbacks + def runInteraction(self, desc, func, *args, **kwargs): + """Starts a transaction on the database and runs a given function + + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + after_callbacks = [] + exception_callbacks = [] + + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warning("Starting db txn '%s' from sentinel context", desc) + + try: + result = yield self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) + + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + return result + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warning( + "Starting db connection from sentinel context: metrics will be lost" + ) + parent_context = None + + start_time = monotonic_time() + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = monotonic_time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) + + if self.database_engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + return func(conn, *args, **kwargs) + + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) + + return result + + @staticmethod + def cursor_to_dict(cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = list(intern(str(column[0])) for column in cursor.description) + results = list(dict(zip(col_headers, row)) for row in cursor) + return results + + def execute(self, desc, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + + def interaction(txn): + txn.execute(query, args) + if decoder: + return decoder(txn) + else: + return txn.fetchall() + + return self.runInteraction(desc, interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + @defer.inlineCallbacks + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True + """ + try: + yield self.runInteraction(desc, self.simple_insert_txn, table, values) + except self.database_engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise + return False + return True + + @staticmethod + def simple_insert_txn(txn, table, values): + keys, vals = zip(*values.items()) + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute(sql, vals) + + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + + @staticmethod + def simple_insert_many_txn(txn, table, values): + if not values: + return + + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) + + for k in keys: + if k != keys[0]: + raise RuntimeError("All items must have the same keys") + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.executemany(sql, vals) + + @defer.inlineCallbacks + def simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="simple_upsert", + lock=True, + ): + """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self.simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, + ) + return result + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warning( + "IntegrityError when upserting into %s; retrying: %s", table, e + ) + + def simple_upsert_txn( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self.simple_upsert_txn_native_upsert( + txn, table, keyvalues, values, insertion_values=insertion_values + ) + else: + return self.simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + bool: Return True if a new entry was created, False if an existing + one was updated. + """ + # We need to lock the table :(, unless we're *really* careful + if lock: + self.database_engine.lock_table(txn, table) + + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(values.values()) + list(keyvalues.values()) + + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False + + # We didn't find any existing rows, so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True + + def simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(insertion_values) + + if not values: + latter = "NOTHING" + else: + allvalues.update(values) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) + + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + latter, + ) + txn.execute(sql, list(allvalues.values())) + + def simple_upsert_many_txn( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self.simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self.simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def simple_upsert_many_txn_emulated( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def simple_upsert_many_txn_native_upsert( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, using batching where possible. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + allnames = [] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning multiple columns from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self.runInteraction( + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + ) + + def simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="simple_select_one_onecol", + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + return self.runInteraction( + desc, + self.simple_select_one_onecol_txn, + table, + keyvalues, + retcol, + allow_none=allow_none, + ) + + @classmethod + def simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): + ret = cls.simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, retcol=retcol + ) + + if ret: + return ret[0] + else: + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + @staticmethod + def simple_select_onecol_txn(txn, table, keyvalues, retcol): + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} + + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) + + return [r[0] for r in txn] + + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" + ): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict|None): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + return self.runInteraction( + desc, self.simple_select_onecol_txn, table, keyvalues, retcol + ) + + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, self.simple_select_list_txn, table, keyvalues, retcols + ) + + @classmethod + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + txn.execute(sql, list(keyvalues.values())) + else: + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + txn.execute(sql) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="simple_select_many_batch", + batch_size=100, + ): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] + + if not iterable: + return results + + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + + chunks = [ + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) + ] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self.simple_select_many_txn, + table, + column, + chunk, + keyvalues, + retcols, + ) + + results.extend(rows) + + return results + + @classmethod + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + if not iterable: + return [] + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return cls.cursor_to_dict(txn) + + def simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, self.simple_update_txn, table, keyvalues, updatevalues + ) + + @staticmethod + def simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) + + return txn.rowcount + + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" + ): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. + + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self.runInteraction( + desc, self.simple_update_one_txn, table, keyvalues, updatevalues + ) + + @classmethod + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) + + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + @staticmethod + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(select_sql, list(keyvalues.values())) + row = txn.fetchone() + + if not row: + if allow_none: + return None + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + return dict(zip(retcols, row)) + + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + + @staticmethod + def simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + if txn.rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + + @staticmethod + def simple_delete_txn(txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount + + def simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted + """ + if not iterable: + return 0 + + sql = "DELETE FROM %s" % table + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + txn.execute(sql, values) + + return txn.rowcount + + def get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - %(limit)s" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + "limit": limit, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + + cache = {row[0]: int(row[1]) for row in txn} + + txn.close() + + if cache: + min_val = min(itervalues(cache)) + else: + min_val = max_value + + return cache, min_val + + def simple_select_list_paginate( + self, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + desc="simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self.simple_select_list_paginate_txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction=order_direction, + ) + + @classmethod + def simple_select_list_paginate_txn( + cls, + txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + + if keyvalues: + where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) + else: + where_clause = "" + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, list(keyvalues.values()) + [limit, start]) + + return cls.cursor_to_dict(txn) + + def get_user_count_txn(self, txn): + """Get a total number of registered users in the users list. + + Args: + txn : Transaction object + Returns: + int : number of users + """ + sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" + txn.execute(sql_count) + return txn.fetchone()[0] + + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, self.simple_search_list_txn, table, term, col, retcols + ) + + @classmethod + def simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 380fd0d107..7f7962c3dd 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -45,13 +45,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store.simple_select_list( + return self.store.db.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store.simple_select_one( + self.store.db.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -180,7 +180,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.handler.stats_enabled = True self.store._all_done = False self.get_success( - self.store.simple_update_one( + self.store.db.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,7 +188,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) @@ -205,13 +205,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -656,12 +656,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_delete( + self.store.db.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store.simple_delete( + self.store.db.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -675,7 +675,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +685,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +695,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index d5b1c5b4ac..bc9d441541 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -184,7 +184,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 124ce0768a..0ed2594381 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store.simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 7b7434a468..d491ea2924 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage.simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage.simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fabe3fbc0..e360297df9 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -37,7 +37,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): def update(progress, count): self.clock.advance_time_msec(count * duration_ms) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield self.store.db.runInteraction( "update_progress", self.store._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index de5e4a5fce..7915d48a9e 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -65,7 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +77,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +106,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +122,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +137,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +150,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +165,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 69dcaa63d5..e454bbff29 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -62,7 +62,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 25bdd2c163..c4f838907c 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -218,7 +218,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store.simple_update( + self.store.db.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": "device_id"}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -245,7 +245,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -297,7 +297,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -323,7 +323,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fe50377f8..eadfb90a22 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others r = yield self.store.get_prev_events_for_room(room_id) @@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.runInteraction("insert", insert_event, i, room1) - yield self.store.runInteraction("insert", insert_event, i, room2) - yield self.store.runInteraction("insert", insert_event, i, room3) + yield self.store.db.runInteraction("insert", insert_event, i, room1) + yield self.store.db.runInteraction("insert", insert_event, i, room2) + yield self.store.db.runInteraction("insert", insert_event, i, room3) # Test simple case r = yield self.store.get_rooms_with_many_extremities(5, 5, []) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 2337a1ae46..d4bcf1821e 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 90a63dc477..3c78faab45 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) count = self.store.get_monthly_active_count() @@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4930b6777e..dc45173355 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store.simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store.simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index d389cf578f..5f957680a2 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -132,7 +132,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", diff --git a/tests/unittest.py b/tests/unittest.py index 295573bc46..fc856a574a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().simple_insert( + self.hs.get_datastore().db.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", -- cgit 1.5.1 From 4a33a6dd19590b8e6626a5af5a69507dc11236f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 15:09:36 +0000 Subject: Move background update handling out of store --- synapse/app/homeserver.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/storage/background_updates.py | 15 +++---- synapse/storage/data_stores/main/client_ips.py | 36 ++++++++-------- synapse/storage/data_stores/main/deviceinbox.py | 9 ++-- synapse/storage/data_stores/main/devices.py | 15 ++++--- .../storage/data_stores/main/event_federation.py | 6 +-- .../storage/data_stores/main/event_push_actions.py | 4 +- synapse/storage/data_stores/main/events.py | 6 +-- .../storage/data_stores/main/events_bg_updates.py | 49 +++++++++++---------- .../storage/data_stores/main/media_repository.py | 6 +-- synapse/storage/data_stores/main/registration.py | 21 ++++----- synapse/storage/data_stores/main/room.py | 9 ++-- synapse/storage/data_stores/main/roommember.py | 27 +++++++----- synapse/storage/data_stores/main/search.py | 31 ++++++++------ synapse/storage/data_stores/main/state.py | 19 ++++---- synapse/storage/data_stores/main/stats.py | 20 ++++----- synapse/storage/data_stores/main/user_directory.py | 33 ++++++++------ synapse/storage/database.py | 3 ++ synmark/__init__.py | 6 +-- tests/handlers/test_stats.py | 50 +++++++++++++++------- tests/handlers/test_user_directory.py | 18 +++++--- tests/storage/test_background_update.py | 26 +++++++---- tests/storage/test_cleanup_extrems.py | 14 ++++-- tests/storage/test_client_ips.py | 26 ++++++++--- tests/storage/test_roommember.py | 18 +++++--- tests/unittest.py | 10 +++-- 27 files changed, 281 insertions(+), 200 deletions(-) (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 267aebaae9..9f81a857ab 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -436,7 +436,7 @@ def setup(config_options): _base.start(hs, config.listeners) hs.get_pusherpool().start() - hs.get_datastore().start_doing_background_updates() + hs.get_datastore().db.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb0d02aa83..6b978be876 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource): logger.info("Running url preview cache expiry") - if not (yield self.store.has_completed_background_updates()): + if not (yield self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index dfca94b0e0..a9a13a2658 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from . import engines -from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdateStore(SQLBaseStore): +class BackgroundUpdater(object): """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, db_conn, hs): - super(BackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, hs, database): + self._clock = hs.get_clock() + self.db = database + self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} @@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info("Starting background schema updates") while True: if sleep: - yield self.hs.get_clock().sleep( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 - ) + yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: result = yield self.do_next_background_update( @@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.database_engine, engines.PostgresEngine): + if isinstance(self.db.database_engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 6f2a720b97..7b470a58f1 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -20,7 +20,7 @@ from six import iteritems from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -32,41 +32,41 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 -class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): +class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -75,12 +75,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) # Drop the old non-unique index - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) @@ -92,7 +92,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.close() yield self.db.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") + yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") return 1 @defer.inlineCallbacks @@ -108,7 +108,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - yield self._end_background_update("user_ips_analyze") + yield self.db.updates._end_background_update("user_ips_analyze") return 1 @@ -271,14 +271,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) yield self.db.runInteraction("user_ips_dups_remove", remove) if last: - yield self._end_background_update("user_ips_remove_dupes") + yield self.db.updates._end_background_update("user_ips_remove_dupes") return batch_size @@ -344,7 +344,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -357,7 +357,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) if not updated: - yield self._end_background_update("devices_last_seen") + yield self.db.updates._end_background_update("devices_last_seen") return updated @@ -546,7 +546,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Nothing to do return - if not await self.has_completed_background_update("devices_last_seen"): + if not await self.db.updates.has_completed_background_update( + "devices_last_seen" + ): # Only start pruning if we have finished populating the devices # last seen info. return diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 440793ad49..3c9f09301a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,6 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -208,20 +207,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) -class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, db_conn, hs): super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) @@ -234,7 +233,7 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index d98511ddd4..91ddaf137e 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,7 +31,6 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,11 +641,11 @@ class DeviceWorkerStore(SQLBaseStore): return results -class DeviceBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -654,7 +653,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # create a unique index on device_lists_remote_cache - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -663,7 +662,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # And one on device_lists_remote_extremeties - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -672,7 +671,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) @@ -686,7 +685,9 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): txn.close() yield self.db.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + yield self.db.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) return 1 diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 77e4353b59..31d2e8eb28 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -494,7 +494,7 @@ class EventFederationStore(EventFederationWorkerStore): def __init__(self, db_conn, hs): super(EventFederationStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -654,7 +654,7 @@ class EventFederationStore(EventFederationWorkerStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) @@ -665,6 +665,6 @@ class EventFederationStore(EventFederationWorkerStore): ) if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 725d0881dc..eec054cd48 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -614,14 +614,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def __init__(self, db_conn, hs): super(EventPushActionsStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 01ec9ec397..d644c82784 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -38,7 +38,6 @@ from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore @@ -94,10 +93,7 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, + StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 365e966956..cb1fc30c31 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -22,13 +22,12 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause logger = logging.getLogger(__name__) -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): +class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -37,15 +36,15 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -56,7 +55,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -65,16 +64,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): psql_only=True, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -82,11 +81,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): where_clause="have_censored", ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) @@ -145,7 +144,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) @@ -156,7 +155,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) return result @@ -222,7 +223,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows_to_update), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) @@ -233,7 +234,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + yield self.db.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) return result @@ -411,7 +414,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + yield self.db.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") @@ -464,7 +469,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) @@ -475,7 +480,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not count: - yield self._end_background_update("redactions_received_ts") + yield self.db.updates._end_background_update("redactions_received_ts") return count @@ -505,7 +510,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self._end_background_update("event_fix_redactions_bytes") + yield self.db.updates._end_background_update("event_fix_redactions_bytes") return 1 @@ -559,7 +564,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): nbrows += 1 last_row_event_id = event_id - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) @@ -570,6 +575,6 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_rows: - yield self._end_background_update("event_store_labels") + yield self.db.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index ea02497784..03c9c6f8ae 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore -class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 8f9aa87ceb..1ef143c6d8 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -26,7 +26,6 @@ from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -794,23 +793,21 @@ class RegistrationWorkerStore(SQLBaseStore): ) -class RegistrationBackgroundUpdateStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, db_conn, hs): super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.register_background_index_update( + self.db.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -820,13 +817,13 @@ class RegistrationBackgroundUpdateStore( # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.register_noop_background_update("refresh_tokens_device_index") + self.db.updates.register_noop_background_update("refresh_tokens_device_index") - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) @@ -873,7 +870,7 @@ class RegistrationBackgroundUpdateStore( logger.info("Marked %d rows as deactivated", rows_processed_nb) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -887,7 +884,7 @@ class RegistrationBackgroundUpdateStore( ) if end: - yield self._end_background_update("users_set_deactivated_flag") + yield self.db.updates._end_background_update("users_set_deactivated_flag") return nb_processed @@ -917,7 +914,7 @@ class RegistrationBackgroundUpdateStore( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self._end_background_update("user_threepids_grandfather") + yield self.db.updates._end_background_update("user_threepids_grandfather") return 1 diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index a26ed47afc..da42dae243 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -28,7 +28,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchStore from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,13 +360,13 @@ class RoomWorkerStore(SQLBaseStore): defer.returnValue(row) -class RoomBackgroundUpdateStore(BackgroundUpdateStore): +class RoomBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) self.config = hs.config - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, ) @@ -421,7 +420,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): logger.info("Inserted %d rows into room_retention", len(rows)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) @@ -435,7 +434,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): ) if end: - yield self._end_background_update("insert_room_retention") + yield self.db.updates._end_background_update("insert_room_retention") defer.returnValue(batch_size) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 7f4d02b25b..929f6b0d39 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -26,8 +26,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + make_in_list_sql_clause, +) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -831,17 +834,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): +class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -909,7 +912,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) @@ -920,7 +923,9 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + yield self.db.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) return result @@ -959,7 +964,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): last_processed_room = next_room - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -978,7 +983,9 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): ) if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + yield self.db.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) return row_count diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 55a604850e..ffa1817e64 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -24,8 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(BackgroundUpdateStore): +class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -49,10 +48,10 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): if not hs.config.enable_search: return - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -61,9 +60,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + self.db.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @@ -153,7 +154,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(event_search_rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) @@ -164,7 +165,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) return result @@ -208,7 +209,9 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): yield self.db.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) return 1 @defer.inlineCallbacks @@ -244,7 +247,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -274,7 +277,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "have_added_indexes": True, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) @@ -285,7 +288,9 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME + ) return num_rows diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 851e81d6b3..7d5a9f8128 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -27,7 +27,6 @@ from synapse.api.errors import NotFoundError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter @@ -1023,9 +1022,7 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" @@ -1034,21 +1031,21 @@ class StateBackgroundUpdateStore( def __init__(self, db_conn, hs): super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", @@ -1181,7 +1178,7 @@ class StateBackgroundUpdateStore( "max_group": max_group, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) @@ -1192,7 +1189,7 @@ class StateBackgroundUpdateStore( ) if finished: - yield self._end_background_update( + yield self.db.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) @@ -1224,7 +1221,7 @@ class StateBackgroundUpdateStore( yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) return 1 diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 974ffc15bd..6b91988c2a 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -68,17 +68,17 @@ class StatsStore(StateDeltasStore): self.stats_delta_processing_lock = DeferredLock() - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") + self.db.updates.register_noop_background_update("populate_stats_cleanup") + self.db.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -102,7 +102,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -123,7 +123,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: @@ -132,7 +132,7 @@ class StatsStore(StateDeltasStore): yield self.db.runInteraction( "populate_stats_process_users", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -145,7 +145,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -166,7 +166,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: @@ -175,7 +175,7 @@ class StatsStore(StateDeltasStore): yield self.db.runInteraction( "_populate_stats_process_rooms", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", progress, ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 7118bd62f3..62ffb34b29 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -19,7 +19,6 @@ import re from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -32,7 +31,7 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -43,19 +42,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore self.server_name = hs.hostname - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) @@ -108,7 +107,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - yield self._end_background_update("populate_user_directory_createtables") + yield self.db.updates._end_background_update( + "populate_user_directory_createtables" + ) return 1 @defer.inlineCallbacks @@ -130,7 +131,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore "populate_user_directory_cleanup", _delete_staging_area ) - yield self._end_background_update("populate_user_directory_cleanup") + yield self.db.updates._end_background_update("populate_user_directory_cleanup") return 1 @defer.inlineCallbacks @@ -176,7 +177,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") + yield self.db.updates._end_background_update( + "populate_user_directory_process_rooms" + ) return 1 logger.info( @@ -248,7 +251,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore progress["remaining"] -= 1 yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,7 +270,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 def _get_next_batch(txn): @@ -297,7 +302,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore # No more users -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 logger.info( @@ -317,7 +324,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore progress["remaining"] -= 1 yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ac64d80806..be36c1b829 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -30,6 +30,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.util.stringutils import exception_to_unicode @@ -223,6 +224,8 @@ class Database(object): self._clock = hs.get_clock() self._db_pool = hs.get_db_pool() + self.updates = BackgroundUpdater(hs, self) + self._previous_txn_total_time = 0 self._current_txn_total_time = 0 self._previous_loop_ts = 0 diff --git a/synmark/__init__.py b/synmark/__init__.py index 570eb818d9..afe4fad8cb 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None): stor = hs.get_datastore() # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not await stor.has_completed_background_updates(): - await stor.do_next_background_update(1) + if hasattr(stor.db.updates, "do_next_background_update"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) def cleanup(): for i in cleanup_tasks: diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7f7962c3dd..d9d312f0fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -42,7 +42,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( @@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,7 +186,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_update_one( table="stats_incremental_position", @@ -194,8 +202,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -653,7 +669,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_delete( @@ -673,7 +689,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( "background_updates", @@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index bc9d441541..26071059d2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -181,7 +181,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( @@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() @@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e360297df9..aec76f4ab1 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler = Mock() - yield self.store.register_background_update_handler( + yield self.store.db.updates.register_background_update_handler( "test_update", self.update_handler ) @@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): # (perhaps we should run them as part of the test HS setup, since we # run all of the other schema setup stuff there?) while True: - res = yield self.store.do_next_background_update(1000) + res = yield self.store.db.updates.do_next_background_update(1000) if res is None: break @@ -39,7 +39,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): progress = {"my_key": progress["my_key"] + 1} yield self.store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.store.db.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update - yield self.store.start_background_update("test_update", {"my_key": 1}) + yield self.store.db.updates.start_background_update( + "test_update", {"my_key": 1} + ) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + yield self.store.db.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index e454bbff29..029ac26454 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, @@ -68,10 +70,14 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c4f838907c..fc279340d4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -202,8 +202,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" @@ -256,11 +260,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # We should now get the correct result again result = self.get_success( @@ -281,8 +289,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5f957680a2..7840f63fe3 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -122,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -143,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/unittest.py b/tests/unittest.py index fc856a574a..68d245ec9f 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -401,10 +401,12 @@ class HomeserverTestCase(TestCase): hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() - # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not self.get_success(stor.has_completed_background_updates()): - self.get_success(stor.do_next_background_update(1)) + # Run the database background updates, when running against "master". + if hs.__class__.__name__ == "TestHomeServer": + while not self.get_success( + stor.db.updates.has_completed_background_updates() + ): + self.get_success(stor.db.updates.do_next_background_update(1)) return hs -- cgit 1.5.1 From 4ca3ef10b9a8d15cf351d67d574088d944c2e3b1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 15:53:10 +0000 Subject: Fixup tests --- tests/rest/client/v1/test_presence.py | 3 +++ tests/rest/client/v1/test_profile.py | 10 +++++++++- tests/utils.py | 4 +++- 3 files changed, 15 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b68707..0fdff79aa7 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b3772..12c5e95cb5 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/utils.py b/tests/utils.py index de2ac1ed33..c57da59191 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -461,7 +461,9 @@ class MockHttpResource(HttpServer): try: args = [urlparse.unquote(u) for u in matcher.groups()] - (code, response) = yield func(mock_request, *args) + (code, response) = yield defer.ensureDeferred( + func(mock_request, *args) + ) return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) -- cgit 1.5.1 From 8437e2383ed2dffacca5395851023eeacb33d7ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 17:58:25 +0000 Subject: Port SyncHandler to async/await --- synapse/handlers/events.py | 30 +++--- synapse/handlers/sync.py | 251 +++++++++++++++++++++----------------------- synapse/notifier.py | 29 +++-- synapse/util/metrics.py | 23 ++-- tests/handlers/test_sync.py | 33 +++--- tests/unittest.py | 7 +- 6 files changed, 182 insertions(+), 191 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 45fe13c62f..ec18a42a68 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -16,8 +16,6 @@ import logging import random -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase @@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks @log_function - def get_stream( + async def get_stream( self, auth_user_id, pagin_config, @@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler): """ if room_id: - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(auth_user_id) auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() - context = yield presence_handler.user_syncing( + context = await presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence ) with context: @@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) - events, tokens = yield self.notifier.get_events_for( + events, tokens = await self.notifier.get_events_for( auth_user, pagin_config, timeout, @@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room( + users = await self.state.get_current_users_in_room( event.room_id ) - states = yield presence_handler.get_states(users, as_event=True) + states = await presence_handler.get_states(users, as_event=True) to_add.extend(states) else: - ev = yield presence_handler.get_state( + ev = await presence_handler.get_state( UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() - chunks = yield self._event_serializer.serialize_events( + chunks = await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, @@ -151,8 +148,7 @@ class EventHandler(BaseHandler): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - @defer.inlineCallbacks - def get_event(self, user, room_id, event_id): + async def get_event(self, user, room_id, event_id): """Retrieve a single specified event. Args: @@ -167,15 +163,15 @@ class EventHandler(BaseHandler): AuthError if the user does not have the rights to inspect this event. """ - event = yield self.store.get_event(event_id, check_room_id=room_id) + event = await self.store.get_event(event_id, check_room_id=room_id) if not event: return None - users = yield self.store.get_users_in_room(event.room_id) + users = await self.store.get_users_in_room(event.room_id) is_peeking = user.to_string() not in users - filtered = yield filter_events_for_client( + filtered = await filter_events_for_client( self.storage, user.to_string(), [event], is_peeking=is_peeking ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b536d410e5..12751fd8c0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -22,8 +22,6 @@ from six import iteritems, itervalues from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user @@ -241,8 +239,7 @@ class SyncHandler(object): expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - @defer.inlineCallbacks - def wait_for_sync_for_user( + async def wait_for_sync_for_user( self, sync_config, since_token=None, timeout=0, full_state=False ): """Get the sync for a client if we have new data for it now. Otherwise @@ -255,9 +252,9 @@ class SyncHandler(object): # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) - res = yield self.response_cache.wrap( + res = await self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, sync_config, @@ -267,8 +264,9 @@ class SyncHandler(object): ) return res - @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + async def _wait_for_sync_for_user( + self, sync_config, since_token, timeout, full_state + ): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -283,7 +281,7 @@ class SyncHandler(object): if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user( + result = await self.current_sync_for_user( sync_config, since_token, full_state=full_state ) else: @@ -291,7 +289,7 @@ class SyncHandler(object): def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) - result = yield self.notifier.wait_for_events( + result = await self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, @@ -314,15 +312,13 @@ class SyncHandler(object): """ return self.generate_sync_result(sync_config, since_token, full_state) - @defer.inlineCallbacks - def push_rules_for_user(self, user): + async def push_rules_for_user(self, user): user_id = user.to_string() - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) return rules - @defer.inlineCallbacks - def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): + async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): """Get the ephemeral events for each room the user is in Args: sync_result_builder(SyncResultBuilder) @@ -343,7 +339,7 @@ class SyncHandler(object): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events( + typing, typing_key = typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -365,7 +361,7 @@ class SyncHandler(object): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events( + receipts, receipt_key = await receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -382,8 +378,7 @@ class SyncHandler(object): return now_token, ephemeral_by_room - @defer.inlineCallbacks - def _load_filtered_recents( + async def _load_filtered_recents( self, room_id, sync_config, @@ -415,10 +410,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - recents = yield filter_events_for_client( + recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), recents, @@ -449,14 +444,14 @@ class SyncHandler(object): # Otherwise, we want to return the last N events in the room # in toplogical ordering. if since_key: - events, end_key = yield self.store.get_room_events_stream_for_room( + events, end_key = await self.store.get_room_events_stream_for_room( room_id, limit=load_limit + 1, from_key=since_key, to_key=end_key, ) else: - events, end_key = yield self.store.get_recent_events_for_room( + events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( @@ -468,10 +463,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - loaded_recents = yield filter_events_for_client( + loaded_recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), loaded_recents, @@ -498,8 +493,7 @@ class SyncHandler(object): limited=limited or newly_joined_room, ) - @defer.inlineCallbacks - def get_state_after_event(self, event, state_filter=StateFilter.all()): + async def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event @@ -511,7 +505,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -519,8 +513,9 @@ class SyncHandler(object): state_ids[(event.type, event.state_key)] = event.event_id return state_ids - @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): + async def get_state_at( + self, room_id, stream_position, state_filter=StateFilter.all() + ): """ Get the room state at a particular stream position Args: @@ -536,13 +531,13 @@ class SyncHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event( + state = await self.get_state_after_event( last_event, state_filter=state_filter ) @@ -551,8 +546,7 @@ class SyncHandler(object): state = {} return state - @defer.inlineCallbacks - def compute_summary(self, room_id, sync_config, batch, state, now_token): + async def compute_summary(self, room_id, sync_config, batch, state, now_token): """ Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds @@ -574,7 +568,7 @@ class SyncHandler(object): # FIXME: we could/should get this from room_stats when matthew/stats lands # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 - last_events, _ = yield self.store.get_recent_event_ids_for_room( + last_events, _ = await self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1 ) @@ -582,7 +576,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -590,7 +584,7 @@ class SyncHandler(object): ) # this is heavily cached, thus: fast. - details = yield self.store.get_room_summary(room_id) + details = await self.store.get_room_summary(room_id) name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -608,12 +602,12 @@ class SyncHandler(object): # calculating heroes. Empty strings are falsey, so we check # for the "name" value and default to an empty string. if name_id: - name = yield self.store.get_event(name_id, allow_none=True) + name = await self.store.get_event(name_id, allow_none=True) if name and name.content.get("name"): return summary if canonical_alias_id: - canonical_alias = yield self.store.get_event( + canonical_alias = await self.store.get_event( canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): @@ -678,7 +672,7 @@ class SyncHandler(object): ) ] - missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = await self.store.get_events(missing_hero_event_ids) missing_hero_state = missing_hero_state.values() for s in missing_hero_state: @@ -697,8 +691,7 @@ class SyncHandler(object): logger.debug("found LruCache for %r", cache_key) return cache - @defer.inlineCallbacks - def compute_state_delta( + async def compute_state_delta( self, room_id, batch, sync_config, since_token, now_token, full_state ): """ Works out the difference in state between the start of the timeline @@ -759,16 +752,16 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -783,13 +776,13 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.state_store.get_state_ids_for_event( + state_at_timeline_start = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: # We can get here if the user has ignored the senders of all # the recent events. - state_at_timeline_start = yield self.get_state_at( + state_at_timeline_start = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -807,19 +800,19 @@ class SyncHandler(object): # about them). state_filter = StateFilter.all() - state_at_previous_sync = yield self.get_state_at( + state_at_previous_sync = await self.get_state_at( room_id, stream_position=since_token, state_filter=state_filter ) if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: # Its not clear how we get here, but empirically we do # (#5407). Logging has been added elsewhere to try and # figure out where this state comes from. - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -843,7 +836,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -883,7 +876,7 @@ class SyncHandler(object): state = {} if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) return { (e.type, e.state_key): e @@ -892,10 +885,9 @@ class SyncHandler(object): ) } - @defer.inlineCallbacks - def unread_notifs_for_room_id(self, room_id, sync_config): + async def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( + last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, receipt_type="m.read", @@ -903,7 +895,7 @@ class SyncHandler(object): notifs = [] if last_unread_event_id: - notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) return notifs @@ -912,8 +904,9 @@ class SyncHandler(object): # count is whatever it was last time. return None - @defer.inlineCallbacks - def generate_sync_result(self, sync_config, since_token=None, full_state=False): + async def generate_sync_result( + self, sync_config, since_token=None, full_state=False + ): """Generates a sync result. Args: @@ -928,7 +921,7 @@ class SyncHandler(object): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` - now_token = yield self.event_sources.get_current_token() + now_token = await self.event_sources.get_current_token() logger.info( "Calculating sync response for %r between %s and %s", @@ -944,10 +937,9 @@ class SyncHandler(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - joined_room_ids = yield self.get_rooms_for_user_at( + joined_room_ids = await self.get_rooms_for_user_at( user_id, now_token.room_stream_id ) - sync_result_builder = SyncResultBuilder( sync_config, full_state, @@ -956,11 +948,11 @@ class SyncHandler(object): joined_room_ids=joined_room_ids, ) - account_data_by_room = yield self._generate_sync_entry_for_account_data( + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) - res = yield self._generate_sync_entry_for_rooms( + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) newly_joined_rooms, newly_joined_or_invited_users, _, _ = res @@ -970,13 +962,13 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: - yield self._generate_sync_entry_for_presence( + await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) - yield self._generate_sync_entry_for_to_device(sync_result_builder) + await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = yield self._generate_sync_entry_for_device_list( + device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_users=newly_joined_or_invited_users, @@ -987,11 +979,11 @@ class SyncHandler(object): device_id = sync_config.device_id one_time_key_counts = {} if device_id: - one_time_key_counts = yield self.store.count_e2e_one_time_keys( + one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - yield self._generate_sync_entry_for_groups(sync_result_builder) + await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: @@ -1015,18 +1007,17 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_groups") - @defer.inlineCallbacks - def _generate_sync_entry_for_groups(self, sync_result_builder): + async def _generate_sync_entry_for_groups(self, sync_result_builder): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = yield self.store.get_groups_changes_for_user( + results = self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: - results = yield self.store.get_all_groups_for_user( + results = await self.store.get_all_groups_for_user( user_id, now_token.groups_key ) @@ -1059,8 +1050,7 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_device_list") - @defer.inlineCallbacks - def _generate_sync_entry_for_device_list( + async def _generate_sync_entry_for_device_list( self, sync_result_builder, newly_joined_rooms, @@ -1108,32 +1098,32 @@ class SyncHandler(object): # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = yield self.store.get_users_whose_devices_changed( + users_that_have_changed = await self.store.get_users_whose_devices_changed( since_token.device_list_key, users_who_share_room ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = yield self.state.get_current_users_in_room(room_id) + joined_users = await self.state.get_current_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) - user_signatures_changed = yield self.store.get_users_whose_signatures_changed( + user_signatures_changed = await self.store.get_users_whose_signatures_changed( user_id, since_token.device_list_key ) users_that_have_changed.update(user_signatures_changed) # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = yield self.state.get_current_users_in_room(room_id) + left_users = await self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1143,8 +1133,7 @@ class SyncHandler(object): else: return DeviceLists(changed=[], left=[]) - @defer.inlineCallbacks - def _generate_sync_entry_for_to_device(self, sync_result_builder): + async def _generate_sync_entry_for_to_device(self, sync_result_builder): """Generates the portion of the sync response. Populates `sync_result_builder` with the result. @@ -1165,14 +1154,14 @@ class SyncHandler(object): # We only delete messages when a new message comes in, but that's # fine so long as we delete them at some point. - deleted = yield self.store.delete_messages_for_device( + deleted = await self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) logger.debug( "Deleted %d to-device messages up to %d", deleted, since_stream_id ) - messages, stream_id = yield self.store.get_new_messages_for_device( + messages, stream_id = await self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) @@ -1190,8 +1179,7 @@ class SyncHandler(object): else: sync_result_builder.to_device = [] - @defer.inlineCallbacks - def _generate_sync_entry_for_account_data(self, sync_result_builder): + async def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. @@ -1209,25 +1197,25 @@ class SyncHandler(object): ( account_data, account_data_by_room, - ) = yield self.store.get_updated_account_data_for_user( + ) = self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key ) - push_rules_changed = yield self.store.have_push_rules_changed_for_user( + push_rules_changed = await self.store.have_push_rules_changed_for_user( user_id, int(since_token.push_rules_key) ) if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( account_data, account_data_by_room, - ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) + ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1242,8 +1230,7 @@ class SyncHandler(object): return account_data_by_room - @defer.inlineCallbacks - def _generate_sync_entry_for_presence( + async def _generate_sync_entry_for_presence( self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ): """Generates the presence portion of the sync response. Populates the @@ -1271,7 +1258,7 @@ class SyncHandler(object): presence_key = None include_offline = False - presence, presence_key = yield presence_source.get_new_events( + presence, presence_key = await presence_source.get_new_events( user=user, from_key=presence_key, is_guest=sync_config.is_guest, @@ -1283,12 +1270,12 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states(extra_users_ids) + states = await self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user @@ -1298,8 +1285,9 @@ class SyncHandler(object): sync_result_builder.presence = presence - @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + async def _generate_sync_entry_for_rooms( + self, sync_result_builder, account_data_by_room + ): """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1321,7 +1309,7 @@ class SyncHandler(object): if block_all_room_ephemeral: ephemeral_by_room = {} else: - now_token, ephemeral_by_room = yield self.ephemeral_by_room( + now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, @@ -1333,16 +1321,16 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: - have_changed = yield self._have_rooms_changed(sync_result_builder) + have_changed = await self._have_rooms_changed(sync_result_builder) if not have_changed: - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") return [], [], [], [] - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + ignored_account_data = await self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id ) @@ -1352,18 +1340,18 @@ class SyncHandler(object): ignored_users = frozenset() if since_token: - res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + res = await self._get_rooms_changed(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms, newly_left_rooms = res - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - res = yield self._get_all_rooms(sync_result_builder, ignored_users) + res = await self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res newly_left_rooms = [] - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) def handle_room_entries(room_entry): return self._generate_room_entry( @@ -1376,7 +1364,7 @@ class SyncHandler(object): always_include=sync_result_builder.full_state, ) - yield concurrently_execute(handle_room_entries, room_entries, 10) + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) @@ -1410,8 +1398,7 @@ class SyncHandler(object): newly_left_users, ) - @defer.inlineCallbacks - def _have_rooms_changed(self, sync_result_builder): + async def _have_rooms_changed(self, sync_result_builder): """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. """ @@ -1422,7 +1409,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1435,8 +1422,7 @@ class SyncHandler(object): return True return False - @defer.inlineCallbacks - def _get_rooms_changed(self, sync_result_builder, ignored_users): + async def _get_rooms_changed(self, sync_result_builder, ignored_users): """Gets the the changes that have happened since the last sync. Args: @@ -1461,7 +1447,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1499,11 +1485,11 @@ class SyncHandler(object): continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) @@ -1536,13 +1522,13 @@ class SyncHandler(object): newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) if old_mem_ev and old_mem_ev.membership == Membership.JOIN: @@ -1566,7 +1552,7 @@ class SyncHandler(object): if leave_events: leave_event = leave_events[-1] - leave_stream_token = yield self.store.get_stream_token_for_event( + leave_stream_token = await self.store.get_stream_token_for_event( leave_event.event_id ) leave_token = since_token.copy_and_replace( @@ -1603,7 +1589,7 @@ class SyncHandler(object): timeline_limit = sync_config.filter_collection.timeline_limit() # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, @@ -1652,8 +1638,7 @@ class SyncHandler(object): return room_entries, invited, newly_joined_rooms, newly_left_rooms - @defer.inlineCallbacks - def _get_all_rooms(self, sync_result_builder, ignored_users): + async def _get_all_rooms(self, sync_result_builder, ignored_users): """Returns entries for all rooms for the user. Args: @@ -1677,7 +1662,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=membership_list ) @@ -1700,7 +1685,7 @@ class SyncHandler(object): elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue - invite = yield self.store.get_event(event.event_id) + invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. @@ -1726,8 +1711,7 @@ class SyncHandler(object): return room_entries, invited, [] - @defer.inlineCallbacks - def _generate_room_entry( + async def _generate_room_entry( self, sync_result_builder, ignored_users, @@ -1769,7 +1753,7 @@ class SyncHandler(object): since_token = room_builder.since_token upto_token = room_builder.upto_token - batch = yield self._load_filtered_recents( + batch = await self._load_filtered_recents( room_id, sync_config, now_token=upto_token, @@ -1796,7 +1780,7 @@ class SyncHandler(object): # tag was added by synapse e.g. for server notice rooms. if full_state: user_id = sync_result_builder.sync_config.user.to_string() - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) # If there aren't any tags, don't send the empty tags list down # sync @@ -1821,7 +1805,7 @@ class SyncHandler(object): ): return - state = yield self.compute_state_delta( + state = await self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, full_state=full_state ) @@ -1844,7 +1828,7 @@ class SyncHandler(object): ) or since_token is None ): - summary = yield self.compute_summary( + summary = await self.compute_summary( room_id, sync_config, batch, state, now_token ) @@ -1861,7 +1845,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) + notifs = await self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1887,8 +1871,7 @@ class SyncHandler(object): else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - @defer.inlineCallbacks - def get_rooms_for_user_at(self, user_id, stream_ordering): + async def get_rooms_for_user_at(self, user_id, stream_ordering): """Get set of joined rooms for a user at the given stream ordering. The stream ordering *must* be recent, otherwise this may throw an @@ -1903,7 +1886,7 @@ class SyncHandler(object): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) + joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1921,10 +1904,10 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) - extrems = yield self.store.get_forward_extremeties_for_room( + extrems = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering ) - users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) diff --git a/synapse/notifier.py b/synapse/notifier.py index af161a81d7..5f5f765bea 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -304,8 +304,7 @@ class Notifier(object): without waking up any of the normal user event streams""" self.notify_replication() - @defer.inlineCallbacks - def wait_for_events( + async def wait_for_events( self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START ): """Wait until the callback returns a non empty response or the @@ -313,9 +312,9 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = yield self.event_sources.get_current_token() + current_token = await self.event_sources.get_current_token() if room_ids is None: - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -344,11 +343,11 @@ class Notifier(object): self.hs.get_reactor(), ) with PreserveLoggingContext(): - yield listener.deferred + await listener.deferred current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) if result: break @@ -364,12 +363,11 @@ class Notifier(object): # This happened if there was no timeout or if the timeout had # already expired. current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) return result - @defer.inlineCallbacks - def get_events_for( + async def get_events_for( self, user, pagination_config, @@ -391,15 +389,14 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = yield self.event_sources.get_current_token() + from_token = await self.event_sources.get_current_token() limit = pagination_config.limit - room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) + room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - @defer.inlineCallbacks - def check_for_updates(before_token, after_token): + async def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) @@ -415,7 +412,7 @@ class Notifier(object): if only_keys and name not in only_keys: continue - new_events, new_key = yield source.get_new_events( + new_events, new_key = await source.get_new_events( user=user, from_key=getattr(from_token, keyname), limit=limit, @@ -425,7 +422,7 @@ class Notifier(object): ) if name == "room": - new_events = yield filter_events_for_client( + new_events = await filter_events_for_client( self.storage, user.to_string(), new_events, @@ -461,7 +458,7 @@ class Notifier(object): user_id_for_stream, ) - result = yield self.wait_for_events( + result = await self.wait_for_events( user_id_for_stream, timeout, check_for_updates, diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 3286804322..63ddaaba87 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from functools import wraps @@ -64,12 +65,22 @@ def measure_func(name=None): def wrapper(func): block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + else: + + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = yield func(self, *args, **kwargs) + return r return measured_func diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 31f54bbd7d..758ee071a5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,54 +12,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION -from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.handlers.sync import SyncConfig from synapse.types import UserID import tests.unittest import tests.utils -from tests.utils import setup_test_homeserver -class SyncTestCase(tests.unittest.TestCase): +class SyncTestCase(tests.unittest.HomeserverTestCase): """ Tests Sync Handler. """ - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.sync_handler = SyncHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.hs = hs + self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:server" user_id2 = "@user2:server" sync_config = self._generate_sync_config(user_id1) + self.reactor.advance(100) # So we get not 0 time self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 # Check that the happy case does not throw errors - yield self.store.upsert_monthly_active_user(user_id1) - yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works self.hs.config.hs_disabled = True - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/unittest.py b/tests/unittest.py index 295573bc46..a1bdd963e6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -18,6 +18,7 @@ import gc import hashlib import hmac +import inspect import logging import time @@ -25,7 +26,7 @@ from mock import Mock from canonicaljson import json -from twisted.internet.defer import Deferred, succeed +from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase): self.reactor.pump([by] * 100) def get_success(self, d, by=0.0): + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump(by=by) @@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase): """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump() -- cgit 1.5.1 From b3a4e35ca84a29fe4ccdfb1125ed098c68405d6c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 10:14:59 +0000 Subject: Fixup functions to consistently return deferreds --- synapse/handlers/sync.py | 6 +++--- synapse/handlers/typing.py | 2 +- synapse/storage/data_stores/main/account_data.py | 2 +- synapse/storage/data_stores/main/group_server.py | 4 ++-- tests/handlers/test_typing.py | 24 ++++++++++++++++++------ tests/rest/client/v1/test_typing.py | 4 +++- 6 files changed, 28 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 12751fd8c0..2d3b8ba73c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -339,7 +339,7 @@ class SyncHandler(object): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = typing_source.get_new_events( + typing, typing_key = await typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -1013,7 +1013,7 @@ class SyncHandler(object): now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = self.store.get_groups_changes_for_user( + results = await self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: @@ -1197,7 +1197,7 @@ class SyncHandler(object): ( account_data, account_data_by_room, - ) = self.store.get_updated_account_data_for_user( + ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 856337b7e2..6f78454322 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -313,7 +313,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return events, handler._latest_room_serial + return defer.succeed((events, handler._latest_room_serial)) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index b0d22faf3f..ed97b3ffe5 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -250,7 +250,7 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return {}, {} + return defer.succeed(({}, {})) return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 9e1d12bcb7..d29155a3b5 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -1109,7 +1109,7 @@ class GroupServerStore(SQLBaseStore): user_id, from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_groups_changes_for_user_txn(txn): sql = """ @@ -1139,7 +1139,7 @@ class GroupServerStore(SQLBaseStore): from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_all_groups_changes_txn(txn): sql = """ diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f6d8660285..92b8726093 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -163,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -227,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -279,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -300,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -317,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -335,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30fb77bac8..4bc3aaf02d 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) self.assertEquals( events[0], [ -- cgit 1.5.1 From 9a4fb457cf5918c85068ea249cd2d58b3e2e3cfc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 13:08:40 +0000 Subject: Change DataStores to accept 'database' param. --- synapse/app/federation_sender.py | 5 +++-- synapse/app/user_dir.py | 5 +++-- synapse/replication/slave/storage/_base.py | 5 +++-- synapse/replication/slave/storage/account_data.py | 5 +++-- synapse/replication/slave/storage/client_ips.py | 5 +++-- synapse/replication/slave/storage/deviceinbox.py | 5 +++-- synapse/replication/slave/storage/devices.py | 5 +++-- synapse/replication/slave/storage/events.py | 5 +++-- synapse/replication/slave/storage/filtering.py | 5 +++-- synapse/replication/slave/storage/groups.py | 5 +++-- synapse/replication/slave/storage/presence.py | 5 +++-- synapse/replication/slave/storage/push_rule.py | 5 +++-- synapse/replication/slave/storage/pushers.py | 5 +++-- synapse/replication/slave/storage/receipts.py | 5 +++-- synapse/replication/slave/storage/room.py | 5 +++-- synapse/storage/_base.py | 2 +- synapse/storage/data_stores/main/__init__.py | 5 +++-- synapse/storage/data_stores/main/account_data.py | 9 +++++---- synapse/storage/data_stores/main/appservice.py | 5 +++-- synapse/storage/data_stores/main/client_ips.py | 9 +++++---- synapse/storage/data_stores/main/deviceinbox.py | 9 +++++---- synapse/storage/data_stores/main/devices.py | 9 +++++---- synapse/storage/data_stores/main/event_federation.py | 5 +++-- synapse/storage/data_stores/main/event_push_actions.py | 9 +++++---- synapse/storage/data_stores/main/events.py | 5 +++-- synapse/storage/data_stores/main/events_bg_updates.py | 5 +++-- synapse/storage/data_stores/main/events_worker.py | 5 +++-- synapse/storage/data_stores/main/media_repository.py | 11 +++++++---- synapse/storage/data_stores/main/monthly_active_users.py | 7 ++++--- synapse/storage/data_stores/main/push_rule.py | 5 +++-- synapse/storage/data_stores/main/receipts.py | 9 +++++---- synapse/storage/data_stores/main/registration.py | 13 +++++++------ synapse/storage/data_stores/main/room.py | 9 +++++---- synapse/storage/data_stores/main/roommember.py | 13 +++++++------ synapse/storage/data_stores/main/search.py | 9 +++++---- synapse/storage/data_stores/main/state.py | 13 +++++++------ synapse/storage/data_stores/main/stats.py | 5 +++-- synapse/storage/data_stores/main/stream.py | 5 +++-- synapse/storage/data_stores/main/transactions.py | 5 +++-- synapse/storage/data_stores/main/user_directory.py | 9 +++++---- tests/storage/test_appservice.py | 5 +++-- 41 files changed, 156 insertions(+), 114 deletions(-) (limited to 'tests') diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 448e45e00f..f24920a7d6 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -59,8 +60,8 @@ class FederationSenderSlaveStore( SlavedDeviceStore, SlavedPresenceStore, ): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) # We pull out the current federation stream position now so that we # always have a known value for the federation position in memory so diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index b6d4481725..c01fb34a9b 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree @@ -60,8 +61,8 @@ class UserDirectorySlaveStore( UserDirectoryStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 6ece1d6745..b91a528245 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ import six from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -35,8 +36,8 @@ def __func__(inp): class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bc2f6a12ae..ebe94909cb 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id" ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b4f58cea19..fbf996e33a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -21,8 +22,8 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 9fb6c5c6ff..0c237c6e0f 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index de50748c30..dc625e0d7a 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d0a0eaf75b..29f35b9915 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -59,13 +60,13 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - super(SlavedEventStore, self).__init__(db_conn, hs) + super(SlavedEventStore, self).__init__(database, db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 5c84ebd125..bcb0688954 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,13 +14,14 @@ # limitations under the License. from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..69a4ae42f9 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage import DataStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 747ced0c84..f552e7c972 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,6 +15,7 @@ from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 3655f05e54..eebd5a1fb6 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,17 +15,18 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.database import Database from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) + super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) def get_push_rules_stream_token(self): return ( diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index b4331d0799..f22c2d44a3 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,14 +15,15 @@ # limitations under the License. from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 43d823c601..d40dc6e1f5 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d9ad386b28..3a20f45316 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,14 +14,15 @@ # limitations under the License. from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7e27d4e97..f9e7f9a71e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -37,7 +37,7 @@ class SQLBaseStore(object): per data store (and not one per physical database). """ - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 6adb8adb04..7f5fd81bcf 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -20,6 +20,7 @@ import logging import time from synapse.api.constants import PresenceState +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, @@ -111,7 +112,7 @@ class DataStore( RelationsStore, CacheInvalidationStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine @@ -169,7 +170,7 @@ class DataStore( else: self._cache_id_gen = None - super(DataStore, self).__init__(db_conn, hs) + super(DataStore, self).__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index a96fe9485c..44d20c19bf 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 6b2e12719c..b2f39649fd 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 7b470a58f1..320c5b0f07 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "user_ips_device_index", @@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) - super(ClientIpStore, self).__init__(db_conn, hs) + super(ClientIpStore, self).__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 3c9f09301a..85cfa16850 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_inbox_stream_index", @@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 91ddaf137e..9a828231c4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,6 +31,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_lists_stream_idx", @@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 31d2e8eb28..1f517e8fad 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index eec054cd48..9988a6d3fc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d644c82784..da1529f6ea 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -95,8 +96,8 @@ def _retry_on_integrity_error(func): class EventsStore( StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsStore, self).__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index cb1fc30c31..efee17b929 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e041fc5eac..9ee117ce0f 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -33,6 +33,7 @@ from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache @@ -55,8 +56,8 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 03c9c6f8ae..80ca36dedf 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): """Get the metadata for a local piece of media diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 34bf3a1880..27158534cb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number self.db.new_transaction( - dbconn, + db_conn, "initialise_mau_threepids", [], [], diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index de682cc63a..5ba13aa973 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -72,8 +73,8 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index ac2d45bd5c..96e54d145e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 1ef143c6d8..5e8ecac0ea 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -36,8 +37,8 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index da42dae243..0148be20d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,8 +362,8 @@ class RoomWorkerStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -440,8 +441,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 929f6b0d39..92e3b9c512 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -32,6 +32,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ffa1817e64..4eec2fae5e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -42,8 +43,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,8 +343,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) def store_event_search_txn(self, txn, event, key, value): """Add event to the search table diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 7d5a9f8128..9ef7b48c74 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -28,6 +28,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.util.caches import get_cache_factor_for, intern_string @@ -213,8 +214,8 @@ class StateGroupWorkerStore( STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -1029,8 +1030,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, @@ -1245,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) def _store_event_state_mappings_txn( self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 40579bf965..7bc186e9a1 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 2ff8c57109..140da8dad6 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -47,6 +47,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index c0d155a43c..5b07c2fbc0 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache # py2 sqlite has buffer hardcoded as only binary type, so we must use it, @@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 62ffb34b29..90c180ec6d 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index dfeea24599..1679112d82 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database from tests import unittest from tests.utils import setup_test_homeserver @@ -382,8 +383,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): -- cgit 1.5.1 From 852f80d8a697c9d556b1b2641a2e4c1797cbbb46 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 16:02:50 +0000 Subject: Fixup tests --- tests/replication/slave/storage/_base.py | 5 ++++- tests/storage/test_appservice.py | 12 +++++++----- tests/storage/test_base.py | 3 ++- tests/storage/test_profile.py | 3 +-- tests/storage/test_user_directory.py | 4 +--- tests/test_federation.py | 16 +++++----------- 6 files changed, 20 insertions(+), 23 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index e7472e3a93..3dae83c543 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.storage.database import Database from tests import unittest from tests.server import FakeTransport @@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() - self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE( + Database(hs), self.hs.get_db_conn(), self.hs + ) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1679112d82..2e521e9ab7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -55,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -124,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = TestTransactionStore(database, hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -417,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -433,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -454,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 7915d48a9e..537cfe9f64 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -59,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) - self.datastore = SQLBaseStore(None, hs) + self.datastore = SQLBaseStore(Database(hs), None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24c7fe16c3..9b6f7211ae 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7eea57c0e2..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 7d82b58466..ad165d7295 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase): self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] + self.store = self.homeserver.get_datastore() + # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( @@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event @@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id - ) + extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") -- cgit 1.5.1 From adfdd82b21ae296ed77453b2f51d55414890f162 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 9 Dec 2019 13:59:27 +0000 Subject: Back out perf regression from get_cross_signing_keys_from_cache. (#6494) Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. --- changelog.d/6494.bugfix | 1 + synapse/handlers/e2e_keys.py | 38 ++++++++------------------------------ sytest-blacklist | 3 +++ tests/handlers/test_e2e_keys.py | 8 ++++++++ 4 files changed, 20 insertions(+), 30 deletions(-) create mode 100644 changelog.d/6494.bugfix (limited to 'tests') diff --git a/changelog.d/6494.bugfix b/changelog.d/6494.bugfix new file mode 100644 index 0000000000..78726d5d7f --- /dev/null +++ b/changelog.d/6494.bugfix @@ -0,0 +1 @@ +Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 28c12753c1..57a10daefd 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,7 +264,6 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -284,35 +283,14 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - for user_id in query: - # XXX: consider changing the store functions to allow querying - # multiple users simultaneously. - key = yield self.store.get_e2e_cross_signing_key( - user_id, "master", from_user_id - ) - if key: - master_keys[user_id] = key - - key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing", from_user_id - ) - if key: - self_signing_keys[user_id] = key - - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - if from_user_id == user_id: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[user_id] = key - - return { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } + # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 + return defer.succeed( + { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } + ) @trace @defer.inlineCallbacks diff --git a/sytest-blacklist b/sytest-blacklist index 411cce0692..79b2d4402a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -33,3 +33,6 @@ New federated private chats get full presence information (SYN-115) # Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing # this requirement from the spec Inbound federation of state requires event_id as a mandatory paramater + +# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands +Can upload self-signing keys diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 854eb6c024..fdfa2cbbc4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + test_replace_master_key.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) + @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) + + test_upload_signatures.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) -- cgit 1.5.1