From 476a89707ada05c0767324063d9c5814547d3ae1 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 10 Jun 2020 17:55:03 +0100 Subject: Fix tests --- tests/replication/slave/storage/test_events.py | 6 ++--- tests/storage/test_event_push_actions.py | 32 +++++++++++++++----------- 2 files changed, 21 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1a88c7fb80..bc667454c1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0}, + {"highlight_count": 0, "notify_count": 0, "unread_count": 0}, ) self.persist( @@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1}, + {"highlight_count": 0, "notify_count": 1, "unread_count": 1}, ) self.persist( @@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2}, + {"highlight_count": 1, "notify_count": 2, "unread_count": 2}, ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b45bc9c115..79a88a1480 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -55,13 +55,17 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): user_id = "@user1235:example.com" @defer.inlineCallbacks - def _assert_counts(noitf_count, highlight_count): + def _assert_counts(unread_count, notif_count, highlight_count): counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( counts, - {"notify_count": noitf_count, "highlight_count": highlight_count}, + { + "unread_count": unread_count, + "notify_count": notif_count, + "highlight_count": highlight_count, + }, ) @defer.inlineCallbacks @@ -96,23 +100,23 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): stream, ) - yield _assert_counts(0, 0) + yield _assert_counts(0, 0, 0) yield _inject_actions(1, PlAIN_NOTIF) - yield _assert_counts(1, 0) + yield _assert_counts(1, 1, 0) yield _rotate(2) - yield _assert_counts(1, 0) + yield _assert_counts(1, 1, 0) yield _inject_actions(3, PlAIN_NOTIF) - yield _assert_counts(2, 0) + yield _assert_counts(2, 2, 0) yield _rotate(4) - yield _assert_counts(2, 0) + yield _assert_counts(2, 2, 0) yield _inject_actions(5, PlAIN_NOTIF) yield _mark_read(3, 3) - yield _assert_counts(1, 0) + yield _assert_counts(1, 1, 0) yield _mark_read(5, 5) - yield _assert_counts(0, 0) + yield _assert_counts(0, 0, 0) yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) @@ -121,17 +125,17 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): table="event_push_actions", keyvalues={"1": 1}, desc="" ) - yield _assert_counts(1, 0) + yield _assert_counts(1, 1, 0) yield _mark_read(7, 7) - yield _assert_counts(0, 0) + yield _assert_counts(0, 0, 0) yield _inject_actions(8, HIGHLIGHT) - yield _assert_counts(1, 1) + yield _assert_counts(1, 1, 1) yield 
_rotate(9) - yield _assert_counts(1, 1) + yield _assert_counts(1, 1, 1) yield _rotate(10) - yield _assert_counts(1, 1) + yield _assert_counts(1, 1, 1) @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): -- cgit 1.5.1 From 2a07c5ded67f598376d82c37057ead6571a4276d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 12 Jun 2020 11:08:05 +0100 Subject: Test that a mark_unread action updates the right counter --- tests/storage/test_event_push_actions.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 79a88a1480..1e6ec95315 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -17,11 +17,16 @@ from mock import Mock from twisted.internet import defer +from tests import unittest import tests.unittest import tests.utils USER_ID = "@user:example.com" +MARK_UNREAD = [ + "org.matrix.msc2625.mark_unread", + {"set_tweak": "highlight", "value": False}, +] PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] HIGHLIGHT = [ "notify", @@ -49,6 +54,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): USER_ID, 0, 1000, 20 ) + @unittest.DEBUG @defer.inlineCallbacks def test_count_aggregation(self): room_id = "!foo:example.com" @@ -130,12 +136,17 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _mark_read(7, 7) yield _assert_counts(0, 0, 0) - yield _inject_actions(8, HIGHLIGHT) - yield _assert_counts(1, 1, 1) + yield _inject_actions(8, MARK_UNREAD) + yield _assert_counts(1, 0, 0) yield _rotate(9) - yield _assert_counts(1, 1, 1) - yield _rotate(10) - yield _assert_counts(1, 1, 1) + yield _assert_counts(1, 0, 0) + + yield _inject_actions(10, HIGHLIGHT) + yield _assert_counts(2, 1, 1) + yield _rotate(11) + yield _assert_counts(2, 1, 1) + yield _rotate(12) + yield _assert_counts(2, 1, 1) @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): -- cgit 1.5.1 From 63d9a00bf11b5d0f50c173258a0d24ddc0fb7bdf Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 12 Jun 2020 11:13:30 +0100 Subject: Remove debug logging --- tests/storage/test_event_push_actions.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 1e6ec95315..303dc8571c 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -17,7 +17,6 @@ from mock import Mock from twisted.internet import defer -from tests import unittest import tests.unittest import tests.utils @@ -54,7 +53,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): USER_ID, 0, 1000, 20 ) - @unittest.DEBUG @defer.inlineCallbacks def test_count_aggregation(self): room_id = "!foo:example.com" -- cgit 1.5.1 From 6b1fa3293d5e834b6b66c4b9d83a5f938cbcabde Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 12 Jun 2020 11:28:26 +0100 Subject: Test that a mark_unread action updates the right counter when using a slave store --- tests/replication/slave/storage/test_events.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index bc667454c1..9837d44995 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -191,6 +191,21 @@ class 
SlavedEventStoreTestCase(BaseSlavedStoreTestCase): {"highlight_count": 1, "notify_count": 2, "unread_count": 2}, ) + self.persist( + type="m.room.message", + msgtype="m.text", + body="world", + push_actions=[ + (USER_ID_2, ["org.matrix.msc2625.mark_unread"]) + ], + ) + self.replicate() + self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2, event1.event_id], + {"highlight_count": 1, "notify_count": 2, "unread_count": 3}, + ) + def test_get_rooms_for_user_with_stream_ordering(self): """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated by rows in the events stream -- cgit 1.5.1 From 7e80c84902f2d34aff1bb8b4c5833cb33d3dc653 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 12 Jun 2020 11:31:11 +0100 Subject: Lint --- tests/replication/slave/storage/test_events.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 9837d44995..cd8680e812 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -195,9 +195,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.message", msgtype="m.text", body="world", - push_actions=[ - (USER_ID_2, ["org.matrix.msc2625.mark_unread"]) - ], + push_actions=[(USER_ID_2, ["org.matrix.msc2625.mark_unread"])], ) self.replicate() self.check( -- cgit 1.5.1 From 7d2532be36dc116e130ad226a7462bb0e899aca4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Jun 2020 08:44:54 -0400 Subject: Discard RDATA from already seen positions. (#7648) --- changelog.d/7648.bugfix | 1 + synapse/app/generic_worker.py | 5 ++ synapse/replication/tcp/commands.py | 4 +- synapse/replication/tcp/handler.py | 30 ++++++++-- tests/replication/tcp/streams/test_events.py | 74 ++++++++++++++++++----- tests/replication/tcp/streams/test_typing.py | 88 ++++++++++++++++++++++++++-- 6 files changed, 175 insertions(+), 27 deletions(-) create mode 100644 changelog.d/7648.bugfix (limited to 'tests') diff --git a/changelog.d/7648.bugfix b/changelog.d/7648.bugfix new file mode 100644 index 0000000000..ff2417bfb6 --- /dev/null +++ b/changelog.d/7648.bugfix @@ -0,0 +1 @@ +In working mode, ensure that replicated data has not already been received. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f3ec2a34ec..53c488d211 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -738,6 +738,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): except Exception: logger.exception("Error processing replication") + async def on_position(self, stream_name: str, instance_name: str, token: int): + await super().on_position(stream_name, instance_name, token) + # Also call on_rdata to ensure that stream positions are properly reset. + await self.on_rdata(stream_name, instance_name, token, []) + def stop_pusher(self, user_id, app_id, pushkey): if not self.notify_pushers: return diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c04f622816..ea5937a20c 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -149,7 +149,7 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream postition without + """Sent by the server to tell the client the stream position without needing to send an RDATA. 
Format:: @@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """Sent by either side as a keep alive. The data is arbitrary (often timestamp) """ NAME = "PING" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index cbcf46f3ae..e6a2e2598b 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -112,8 +112,8 @@ class ReplicationCommandHandler: "replication_position", clock=self._clock ) - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. + # Map of stream name to batched updates. See RdataCommand for info on + # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. @@ -123,7 +123,8 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming streams that are coming from that connection + # For each connection, the incoming stream names that are coming from + # that connection. self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] LaterGauge( @@ -310,7 +311,28 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + stream = self._streams.get(stream_name) + if not stream: + logger.error("Got RDATA for unknown stream: %s", stream_name) + return + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 51bf0ef4e9..097e1653b4 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -17,6 +17,7 @@ from typing import List, Optional from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase +from synapse.replication.tcp.commands import RdataCommand from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams.events import ( EventsStreamCurrentStateRow, @@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase): # also one state event state_event = self._inject_state_event() - # tell the notifier to catch up to avoid duplicate rows. 
- # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received self.assertEqual([], self.test_handler.received_rdata_rows) @@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase): # one more bit of state that doesn't get rolled back state2 = self._inject_state_event() - # tell the notifier to catch up to avoid duplicate rows. - # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received self.assertEqual([], self.test_handler.received_rdata_rows) @@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase): prev_events = [e.event_id] pl_events.append(e) - # tell the notifier to catch up to avoid duplicate rows. - # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received self.assertEqual([], self.test_handler.received_rdata_rows) @@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) + def test_backwards_stream_id(self): + """ + Test that RDATA that comes after the current position should be discarded. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # Generate an events. We inject them using inject_event so that they are + # not send out over replication until we call self.replicate(). + event = self._inject_test_event() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # We should have received the expected single row (as well as various + # cache invalidation updates which we ignore). + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + + # There should be a single received row. + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows[0] + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + # Reset the data. + self.test_handler.received_rdata_rows = [] + + # Save the current token for later. + worker_events_stream = self.worker_hs.get_replication_streams()["events"] + prev_token = worker_events_stream.current_token("master") + + # Manually send an old RDATA command, which should get dropped. This + # re-uses the row from above, but with an earlier stream token. + self.hs.get_tcp_replication().send_command( + RdataCommand("events", "master", 1, row) + ) + + # No updates have been received (because it was discard as old). + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + self.assertEqual(len(received_rows), 0) + + # Ensure the stream has not gone backwards. 
+ current_token = worker_events_stream.current_token("master") + self.assertGreaterEqual(current_token, prev_token) + event_count = 0 def _inject_test_event( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index fd62b26356..5acfb3e53e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -16,10 +16,15 @@ from mock import Mock from synapse.handlers.typing import RoomMember from synapse.replication.tcp.streams import TypingStream +from synapse.util.caches.stream_change_cache import StreamChangeCache from tests.replication._base import BaseStreamTestCase USER_ID = "@feeling:blue" +USER_ID_2 = "@da-ba-dee:blue" + +ROOM_ID = "!bar:blue" +ROOM_ID_2 = "!foo:blue" class TypingStreamTestCase(BaseStreamTestCase): @@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase): def test_typing(self): typing = self.hs.get_typing_handler() - room_id = "!bar:blue" - self.reconnect() - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) self.reactor.advance(0) @@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow - self.assertEqual(room_id, row.room_id) + self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) # Now let's disconnect and insert some data. @@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.test_handler.on_rdata.reset_mock() - typing._push_update(member=RoomMember(room_id, USER_ID), typing=False) + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False) self.test_handler.on_rdata.assert_not_called() @@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] - self.assertEqual(room_id, row.room_id) + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual([], row.user_ids) + + def test_reset(self): + """ + Test what happens when a typing stream resets. + + This is emulated by jumping the stream ahead, then reconnecting (which + sends the proper position and RDATA). + """ + typing = self.hs.get_typing_handler() + + self.reconnect() + + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) + + self.reactor.advance(0) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] # type: TypingStream.TypingStreamRow + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual([USER_ID], row.user_ids) + + # Push the stream forward a bunch so it can be reset. + for i in range(100): + typing._push_update( + member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True + ) + self.reactor.advance(0) + + # Disconnect. 
+ self.disconnect() + + # Reset the typing handler + self.hs.get_replication_streams()["typing"].last_token = 0 + self.hs.get_tcp_replication()._streams["typing"].last_token = 0 + typing._latest_room_serial = 0 + typing._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", typing._latest_room_serial + ) + typing._reset() + + # Reconnect. + self.reconnect() + self.pump(0.1) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + # Reset the test code. + self.test_handler.on_rdata.reset_mock() + self.test_handler.on_rdata.assert_not_called() + + # Push additional data. + typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) + self.reactor.advance(0) + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] + self.assertEqual(ROOM_ID_2, row.room_id) self.assertEqual([], row.user_ids) + + # The token should have been reset. + self.assertEqual(token, 1) -- cgit 1.5.1 From 2b2344652b215b8023fb37deeacbb395f3c68d7c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 15 Jun 2020 13:42:44 -0400 Subject: Ensure etag is a string for GET room_keys/version response (#7691) --- changelog.d/7691.bugfix | 1 + synapse/handlers/e2e_room_keys.py | 1 + tests/handlers/test_e2e_room_keys.py | 1 + 3 files changed, 3 insertions(+) create mode 100644 changelog.d/7691.bugfix (limited to 'tests') diff --git a/changelog.d/7691.bugfix b/changelog.d/7691.bugfix new file mode 100644 index 0000000000..2a8a480c53 --- /dev/null +++ b/changelog.d/7691.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where the response to the `GET room_keys/version` endpoint had the incorrect type for the `etag` field. diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 2efea801bc..f55470a707 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -349,6 +349,7 @@ class E2eRoomKeysHandler(object): raise res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["etag"] = str(res["etag"]) return res @trace diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 70f172eb02..822ea42dde 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -96,6 +96,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) version_etag = res["etag"] + self.assertIsInstance(version_etag, str) del res["etag"] self.assertDictEqual( res, -- cgit 1.5.1 From cc32fa7358641b96f5d3dbc14d0cd068e676e256 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Jun 2020 16:20:34 -0400 Subject: Ensure the body is a string before comparing push rules. 
(#7701) --- changelog.d/7701.bugfix | 1 + synapse/push/push_rule_evaluator.py | 4 ++-- tests/push/test_push_rule_evaluator.py | 39 ++++++++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7701.bugfix (limited to 'tests') diff --git a/changelog.d/7701.bugfix b/changelog.d/7701.bugfix new file mode 100644 index 0000000000..e5b10f75fd --- /dev/null +++ b/changelog.d/7701.bugfix @@ -0,0 +1 @@ +Do not break push rule evaluation when receiving an event with a non-string body. This is a long-standing bug. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 11032491af..aeac257a6e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -131,7 +131,7 @@ class PushRuleEvaluatorForEvent(object): # XXX: optimisation: cache our pattern regexps if condition["key"] == "content.body": body = self._event.content.get("body", None) - if not body: + if not body or not isinstance(body, str): return False return _glob_matches(pattern, body, word_boundary=True) @@ -147,7 +147,7 @@ class PushRuleEvaluatorForEvent(object): return False body = self._event.content.get("body", None) - if not body: + if not body or not isinstance(body, str): return False # Similar to _glob_matches, but do not treat display_name as a glob. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 9ae6a87d7b..af35d23aea 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -21,7 +21,7 @@ from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def setUp(self): + def _get_evaluator(self, content): event = FrozenEvent( { "event_id": "$event_id", @@ -29,37 +29,58 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "sender": "@user:test", "state_key": "", "room_id": "@room:test", - "content": {"body": "foo bar baz"}, + "content": content, }, RoomVersions.V1, ) room_member_count = 0 sender_power_level = 0 power_levels = {} - self.evaluator = PushRuleEvaluatorForEvent( + return PushRuleEvaluatorForEvent( event, room_member_count, sender_power_level, power_levels ) def test_display_name(self): """Check for a matching display name in the body of the event.""" + evaluator = self._get_evaluator({"body": "foo bar baz"}) + condition = { "kind": "contains_display_name", } # Blank names are skipped. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + self.assertFalse(evaluator.matches(condition, "@user:test", "")) # Check a display name that doesn't match. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + self.assertFalse(evaluator.matches(condition, "@user:test", "not found")) # Check a display name which matches. - self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) # A display name that matches, but not a full word does not result in a match. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + self.assertFalse(evaluator.matches(condition, "@user:test", "ba")) # A display name should not be interpreted as a regular expression. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]")) # A display name with spaces should work fine. 
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) + self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + + def test_no_body(self): + """Not having a body shouldn't break the evaluator.""" + evaluator = self._get_evaluator({}) + + condition = { + "kind": "contains_display_name", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_invalid_body(self): + """A non-string body should not break the evaluator.""" + condition = { + "kind": "contains_display_name", + } + + for body in (1, True, {"foo": "bar"}): + evaluator = self._get_evaluator({"body": body}) + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) -- cgit 1.5.1 From 3e6b5bba7177274db5533cc5aae0a0f8acf71597 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 16 Jun 2020 10:13:59 +0100 Subject: Wrap register_device coroutine in an ensureDeferred (#7684) Fixes https://github.com/matrix-org/synapse/issues/7683 Broke in: #7649 We had a `yield` acting on a coroutine. To be fair this one is a bit difficult to notice as there's a function in the middle that just passes the coroutine along. --- changelog.d/7684.bugfix | 1 + synapse/module_api/__init__.py | 12 ++++++---- tests/module_api/__init__.py | 0 tests/module_api/test_api.py | 54 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 changelog.d/7684.bugfix create mode 100644 tests/module_api/__init__.py create mode 100644 tests/module_api/test_api.py (limited to 'tests') diff --git a/changelog.d/7684.bugfix b/changelog.d/7684.bugfix new file mode 100644 index 0000000000..a93a92ea8b --- /dev/null +++ b/changelog.d/7684.bugfix @@ -0,0 +1 @@ +Fix a bug that would crash Synapse on start when using certain password auth providers. Broke in release v1.15.0. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ecdf1ad69f..a7849cefa5 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -126,7 +126,7 @@ class ModuleApi(object): 'errcode' property for more information on the reason for failure Returns: - Deferred[str]: user_id + defer.Deferred[str]: user_id """ return defer.ensureDeferred( self._hs.get_registration_handler().register_user( @@ -149,10 +149,12 @@ class ModuleApi(object): Returns: defer.Deferred[tuple[str, str]]: Tuple of device ID and access token """ - return self._hs.get_registration_handler().register_device( - user_id=user_id, - device_id=device_id, - initial_display_name=initial_display_name, + return defer.ensureDeferred( + self._hs.get_registration_handler().register_device( + user_id=user_id, + device_id=device_id, + initial_display_name=initial_display_name, + ) ) def record_user_external_id( diff --git a/tests/module_api/__init__.py b/tests/module_api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py new file mode 100644 index 0000000000..807cd65dd6 --- /dev/null +++ b/tests/module_api/test_api.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.module_api import ModuleApi + +from tests.unittest import HomeserverTestCase + + +class ModuleApiTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler()) + + def test_can_register_user(self): + """Tests that an external module can register a user""" + # Register a new user + user_id, access_token = self.get_success( + self.module_api.register( + "bob", displayname="Bobberino", emails=["bob@bobinator.bob"] + ) + ) + + # Check that the new user exists with all provided attributes + self.assertEqual(user_id, "@bob:test") + self.assertTrue(access_token) + self.assertTrue(self.store.get_user_by_id(user_id)) + + # Check that the email was assigned + emails = self.get_success(self.store.user_get_threepids(user_id)) + self.assertEqual(len(emails), 1) + + email = emails[0] + self.assertEqual(email["medium"], "email") + self.assertEqual(email["address"], "bob@bobinator.bob") + + # Should these be 0? + self.assertEqual(email["validated_at"], 0) + self.assertEqual(email["added_at"], 0) + + # Check that the displayname was assigned + displayname = self.get_success(self.store.get_profile_displayname("bob")) + self.assertEqual(displayname, "Bobberino") -- cgit 1.5.1 From 03619324fc18632a2907ace4d3e73f3c4dd0b05e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 16 Jun 2020 12:44:07 +0100 Subject: Create a ListenerConfig object (#7681) This ended up being a bit more invasive than I'd hoped for (not helped by generic_worker duplicating some of the code from homeserver), but hopefully it's an improvement. The idea is that, rather than storing unstructured `dict`s in the config for the listener configurations, we instead parse it into a structured `ListenerConfig` object. --- changelog.d/7681.misc | 1 + synapse/app/_base.py | 8 +- synapse/app/generic_worker.py | 36 +++--- synapse/app/homeserver.py | 50 ++++---- synapse/config/server.py | 235 ++++++++++++++++++++++++-------------- synapse/config/workers.py | 24 ++-- synapse/http/site.py | 6 +- synapse/python_dependencies.py | 5 +- tests/app/test_frontend_proxy.py | 26 ++--- tests/app/test_openid_listener.py | 8 +- tests/test_server.py | 13 ++- tests/unittest.py | 2 +- tests/utils.py | 1 + 13 files changed, 248 insertions(+), 167 deletions(-) create mode 100644 changelog.d/7681.misc (limited to 'tests') diff --git a/changelog.d/7681.misc b/changelog.d/7681.misc new file mode 100644 index 0000000000..e474fc39cd --- /dev/null +++ b/changelog.d/7681.misc @@ -0,0 +1 @@ +Refactor handling of `listeners` configuration settings. 
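A minimal sketch of the behaviour this change introduces: a listener stanza from homeserver.yaml is parsed by the new parse_listener_def helper into a structured ListenerConfig rather than being kept as a raw dict. The port, addresses and resource names below are arbitrary sample values, not taken from the patch.

    from synapse.config.server import parse_listener_def

    # One entry from the "listeners" section of homeserver.yaml
    # (port, addresses and resource names are arbitrary sample values).
    listener = parse_listener_def(
        {
            "port": 8008,
            "type": "http",
            "tls": False,
            "x_forwarded": True,
            "bind_addresses": ["::1", "127.0.0.1"],
            "resources": [{"names": ["client", "federation"], "compress": False}],
        }
    )

    # The result is a frozen ListenerConfig with typed attributes instead of
    # an unstructured dict.
    assert listener.port == 8008
    assert listener.type == "http"
    assert listener.http_options.x_forwarded is True
    assert listener.http_options.resources[0].names == ["client", "federation"]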
diff --git a/synapse/app/_base.py b/synapse/app/_base.py index dedff81af3..373a80a4a7 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -20,6 +20,7 @@ import signal import socket import sys import traceback +from typing import Iterable from daemonize import Daemonize from typing_extensions import NoReturn @@ -29,6 +30,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error +from synapse.config.server import ListenerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext from synapse.util.async_helpers import Linearizer @@ -234,7 +236,7 @@ def refresh_certificate(hs): logger.info("Context factories updated.") -def start(hs, listeners=None): +def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): """ Start a Synapse server or worker. @@ -245,8 +247,8 @@ def start(hs, listeners=None): notify systemd. Args: - hs (synapse.server.HomeServer) - listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml) + hs: homeserver instance + listeners: Listener configuration ('listeners' in homeserver.yaml) """ try: # Set up the SIGHUP machinery. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 53c488d211..27a3fc9ed6 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -37,6 +37,7 @@ from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.config.server import ListenerConfig from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer from synapse.handlers.presence import ( @@ -514,13 +515,18 @@ class GenericWorkerSlavedStore( class GenericWorkerServer(HomeServer): DATASTORE_CLASS = GenericWorkerSlavedStore - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) + def _listen_http(self, listener_config: ListenerConfig): + port = listener_config.port + bind_addresses = listener_config.bind_addresses + + assert listener_config.http_options is not None + + site_tag = listener_config.http_options.tag + if site_tag is None: + site_tag = port resources = {} - for res in listener_config["resources"]: - for name in res["names"]: + for res in listener_config.http_options.resources: + for name in res.names: if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "client": @@ -590,7 +596,7 @@ class GenericWorkerServer(HomeServer): " repository is disabled. Ignoring." ) - if name == "openid" and "federation" not in res["names"]: + if name == "openid" and "federation" not in res.names: # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. 
@@ -625,19 +631,19 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) - def start_listening(self, listeners): + def start_listening(self, listeners: Iterable[ListenerConfig]): for listener in listeners: - if listener["type"] == "http": + if listener.type == "http": self._listen_http(listener) - elif listener["type"] == "manhole": + elif listener.type == "manhole": _base.listen_tcp( - listener["bind_addresses"], - listener["port"], + listener.bind_addresses, + listener.port, manhole( username="matrix", password="rabbithole", globals={"hs": self} ), ) - elif listener["type"] == "metrics": + elif listener.type == "metrics": if not self.get_config().enable_metrics: logger.warning( ( @@ -646,9 +652,9 @@ class GenericWorkerServer(HomeServer): ) ) else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) + _base.listen_metrics(listener.bind_addresses, listener.port) else: - logger.warning("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unsupported listener type: %s", listener.type) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 93bc45208e..299134d00f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -23,6 +23,7 @@ import math import os import resource import sys +from typing import Iterable from prometheus_client import Gauge @@ -48,6 +49,7 @@ from synapse.app import _base from synapse.app._base import listen_ssl, listen_tcp, quit_with_error from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.additional_resource import AdditionalResource from synapse.http.server import ( @@ -87,24 +89,24 @@ def gz_wrap(r): class SynapseHomeServer(HomeServer): DATASTORE_CLASS = DataStore - def _listener_http(self, config, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - tls = listener_config.get("tls", False) - site_tag = listener_config.get("tag", port) + def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig): + port = listener_config.port + bind_addresses = listener_config.bind_addresses + tls = listener_config.tls + site_tag = listener_config.http_options.tag + if site_tag is None: + site_tag = port resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "openid" and "federation" in res["names"]: + for res in listener_config.http_options.resources: + for name in res.names: + if name == "openid" and "federation" in res.names: # Skip loading openid resource if federation is defined # since federation resource will include openid continue - resources.update( - self._configure_named_resource(name, res.get("compress", False)) - ) + resources.update(self._configure_named_resource(name, res.compress)) - additional_resources = listener_config.get("additional_resources", {}) + additional_resources = listener_config.http_options.additional_resources logger.debug("Configuring additional resources: %r", additional_resources) module_api = ModuleApi(self, self.get_auth_handler()) for path, resmodule in additional_resources.items(): @@ -276,7 +278,7 @@ class SynapseHomeServer(HomeServer): return resources - def start_listening(self, listeners): + def start_listening(self, listeners: Iterable[ListenerConfig]): config = 
self.get_config() if config.redis_enabled: @@ -286,25 +288,25 @@ class SynapseHomeServer(HomeServer): self.get_tcp_replication().start_replication(self) for listener in listeners: - if listener["type"] == "http": + if listener.type == "http": self._listening_services.extend(self._listener_http(config, listener)) - elif listener["type"] == "manhole": + elif listener.type == "manhole": listen_tcp( - listener["bind_addresses"], - listener["port"], + listener.bind_addresses, + listener.port, manhole( username="matrix", password="rabbithole", globals={"hs": self} ), ) - elif listener["type"] == "replication": + elif listener.type == "replication": services = listen_tcp( - listener["bind_addresses"], - listener["port"], + listener.bind_addresses, + listener.port, ReplicationStreamProtocolFactory(self), ) for s in services: reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) - elif listener["type"] == "metrics": + elif listener.type == "metrics": if not self.get_config().enable_metrics: logger.warning( ( @@ -313,9 +315,11 @@ class SynapseHomeServer(HomeServer): ) ) else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) + _base.listen_metrics(listener.bind_addresses, listener.port) else: - logger.warning("Unrecognized listener type: %s", listener["type"]) + # this shouldn't happen, as the listener type should have been checked + # during parsing + logger.warning("Unrecognized listener type: %s", listener.type) # Gauges to expose monthly active user control metrics diff --git a/synapse/config/server.py b/synapse/config/server.py index 73226e63d5..8204664883 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional import attr import yaml @@ -57,6 +57,64 @@ on how to configure the new listener. 
--------------------------------------------------------------------------------""" +KNOWN_LISTENER_TYPES = { + "http", + "metrics", + "manhole", + "replication", +} + +KNOWN_RESOURCES = { + "client", + "consent", + "federation", + "keys", + "media", + "metrics", + "openid", + "replication", + "static", + "webclient", +} + + +@attr.s(frozen=True) +class HttpResourceConfig: + names = attr.ib( + type=List[str], + factory=list, + validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore + ) + compress = attr.ib( + type=bool, + default=False, + validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type] + ) + + +@attr.s(frozen=True) +class HttpListenerConfig: + """Object describing the http-specific parts of the config of a listener""" + + x_forwarded = attr.ib(type=bool, default=False) + resources = attr.ib(type=List[HttpResourceConfig], factory=list) + additional_resources = attr.ib(type=Dict[str, dict], factory=dict) + tag = attr.ib(type=str, default=None) + + +@attr.s(frozen=True) +class ListenerConfig: + """Object describing the configuration of a single listener.""" + + port = attr.ib(type=int, validator=attr.validators.instance_of(int)) + bind_addresses = attr.ib(type=List[str]) + type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES)) + tls = attr.ib(type=bool, default=False) + + # http_options is only populated if type=http + http_options = attr.ib(type=Optional[HttpListenerConfig], default=None) + + class ServerConfig(Config): section = "server" @@ -379,38 +437,21 @@ class ServerConfig(Config): } ] - self.listeners = [] # type: List[dict] - for listener in config.get("listeners", []): - if not isinstance(listener.get("port", None), int): - raise ConfigError( - "Listener configuration is lacking a valid 'port' option" - ) + self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])] - if listener.setdefault("tls", False): - # no_tls is not really supported any more, but let's grandfather it in - # here. - if config.get("no_tls", False): + # no_tls is not really supported any more, but let's grandfather it in + # here. 
+ if config.get("no_tls", False): + l2 = [] + for listener in self.listeners: + if listener.tls: logger.info( - "Ignoring TLS-enabled listener on port %i due to no_tls" + "Ignoring TLS-enabled listener on port %i due to no_tls", + listener.port, ) - continue - - bind_address = listener.pop("bind_address", None) - bind_addresses = listener.setdefault("bind_addresses", []) - - # if bind_address was specified, add it to the list of addresses - if bind_address: - bind_addresses.append(bind_address) - - # if we still have an empty list of addresses, use the default list - if not bind_addresses: - if listener["type"] == "metrics": - # the metrics listener doesn't support IPv6 - bind_addresses.append("0.0.0.0") else: - bind_addresses.extend(DEFAULT_BIND_ADDRESSES) - - self.listeners.append(listener) + l2.append(listener) + self.listeners = l2 if not self.web_client_location: _warn_if_webclient_configured(self.listeners) @@ -446,43 +487,41 @@ class ServerConfig(Config): bind_host = config.get("bind_host", "") gzip_responses = config.get("gzip_responses", True) + http_options = HttpListenerConfig( + resources=[ + HttpResourceConfig(names=["client"], compress=gzip_responses), + HttpResourceConfig(names=["federation"]), + ], + ) + self.listeners.append( - { - "port": bind_port, - "bind_addresses": [bind_host], - "tls": True, - "type": "http", - "resources": [ - {"names": ["client"], "compress": gzip_responses}, - {"names": ["federation"], "compress": False}, - ], - } + ListenerConfig( + port=bind_port, + bind_addresses=[bind_host], + tls=True, + type="http", + http_options=http_options, + ) ) unsecure_port = config.get("unsecure_port", bind_port - 400) if unsecure_port: self.listeners.append( - { - "port": unsecure_port, - "bind_addresses": [bind_host], - "tls": False, - "type": "http", - "resources": [ - {"names": ["client"], "compress": gzip_responses}, - {"names": ["federation"], "compress": False}, - ], - } + ListenerConfig( + port=unsecure_port, + bind_addresses=[bind_host], + tls=False, + type="http", + http_options=http_options, + ) ) manhole = config.get("manhole") if manhole: self.listeners.append( - { - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - } + ListenerConfig( + port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + ) ) metrics_port = config.get("metrics_port") @@ -490,13 +529,14 @@ class ServerConfig(Config): logger.warning(METRICS_PORT_WARNING) self.listeners.append( - { - "port": metrics_port, - "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], - "tls": False, - "type": "http", - "resources": [{"names": ["metrics"], "compress": False}], - } + ListenerConfig( + port=metrics_port, + bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")], + type="http", + http_options=HttpListenerConfig( + resources=[HttpResourceConfig(names=["metrics"])] + ), + ) ) _check_resource_config(self.listeners) @@ -522,7 +562,7 @@ class ServerConfig(Config): ) def has_tls_listener(self) -> bool: - return any(listener["tls"] for listener in self.listeners) + return any(listener.tls for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs @@ -1081,6 +1121,44 @@ def read_gc_thresholds(thresholds): ) +def parse_listener_def(listener: Any) -> ListenerConfig: + """parse a listener config from the config file""" + listener_type = listener["type"] + + port = listener.get("port") + if not isinstance(port, int): + raise ConfigError("Listener configuration is lacking 
a valid 'port' option") + + tls = listener.get("tls", False) + + bind_addresses = listener.get("bind_addresses", []) + bind_address = listener.get("bind_address") + # if bind_address was specified, add it to the list of addresses + if bind_address: + bind_addresses.append(bind_address) + + # if we still have an empty list of addresses, use the default list + if not bind_addresses: + if listener_type == "metrics": + # the metrics listener doesn't support IPv6 + bind_addresses.append("0.0.0.0") + else: + bind_addresses.extend(DEFAULT_BIND_ADDRESSES) + + http_config = None + if listener_type == "http": + http_config = HttpListenerConfig( + x_forwarded=listener.get("x_forwarded", False), + resources=[ + HttpResourceConfig(**res) for res in listener.get("resources", []) + ], + additional_resources=listener.get("additional_resources", {}), + tag=listener.get("tag"), + ) + + return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) + + NO_MORE_WEB_CLIENT_WARNING = """ Synapse no longer includes a web client. To enable a web client, configure web_client_location. To remove this warning, remove 'webclient' from the 'listeners' @@ -1088,40 +1166,27 @@ configuration. """ -def _warn_if_webclient_configured(listeners): +def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None: for listener in listeners: - for res in listener.get("resources", []): - for name in res.get("names", []): + if not listener.http_options: + continue + for res in listener.http_options.resources: + for name in res.names: if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return -KNOWN_RESOURCES = ( - "client", - "consent", - "federation", - "keys", - "media", - "metrics", - "openid", - "replication", - "static", - "webclient", -) - - -def _check_resource_config(listeners): +def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None: resource_names = { res_name for listener in listeners - for res in listener.get("resources", []) - for res_name in res.get("names", []) + if listener.http_options + for res in listener.http_options.resources + for res_name in res.names } for resource in resource_names: - if resource not in KNOWN_RESOURCES: - raise ConfigError("Unknown listener resource '%s'" % (resource,)) if resource == "consent": try: check_requirements("resources.consent") diff --git a/synapse/config/workers.py b/synapse/config/workers.py index ed06b91a54..dbc661630c 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -16,6 +16,7 @@ import attr from ._base import Config, ConfigError +from .server import ListenerConfig, parse_listener_def @attr.s @@ -52,7 +53,9 @@ class WorkerConfig(Config): if self.worker_app == "synapse.app.homeserver": self.worker_app = None - self.worker_listeners = config.get("worker_listeners", []) + self.worker_listeners = [ + parse_listener_def(x) for x in config.get("worker_listeners", []) + ] self.worker_daemonize = config.get("worker_daemonize") self.worker_pid_file = config.get("worker_pid_file") self.worker_log_config = config.get("worker_log_config") @@ -75,24 +78,11 @@ class WorkerConfig(Config): manhole = config.get("worker_manhole") if manhole: self.worker_listeners.append( - { - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - } + ListenerConfig( + port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + ) ) - if self.worker_listeners: - for listener in self.worker_listeners: - bind_address = listener.pop("bind_address", None) - bind_addresses = 
listener.setdefault("bind_addresses", []) - - if bind_address: - bind_addresses.append(bind_address) - elif not bind_addresses: - bind_addresses.append("") - # A map from instance name to host/port of their HTTP replication endpoint. instance_map = config.get("instance_map") or {} self.instance_map = { diff --git a/synapse/http/site.py b/synapse/http/site.py index 167293c46d..cbc37eac6e 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -19,6 +19,7 @@ from typing import Optional from twisted.python.failure import Failure from twisted.web.server import Request, Site +from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext @@ -350,7 +351,7 @@ class SynapseSite(Site): self, logger_name, site_tag, - config, + config: ListenerConfig, resource, server_version_string, *args, @@ -360,7 +361,8 @@ class SynapseSite(Site): self.site_tag = site_tag - proxied = config.get("x_forwarded", False) + assert config.http_options is not None + proxied = config.http_options.x_forwarded self.requestFactory = XForwardedForRequest if proxied else SynapseRequest self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8b4312e5a3..8ec1a619a2 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -68,9 +68,8 @@ REQUIREMENTS = [ "phonenumbers>=8.2.0", "six>=1.10", "prometheus_client>=0.0.18,<0.8.0", - # we use attr.s(slots), which arrived in 16.0.0 - # Twisted 18.7.0 requires attrs>=17.4.0 - "attrs>=17.4.0", + # we use attr.validators.deep_iterable, which arrived in 19.1.0 + "attrs>=19.1.0", "netaddr>=0.7.18", "Jinja2>=2.9", "bleach>=1.4.3", diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index be20a89682..641093d349 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -30,6 +30,16 @@ class FrontendProxyTests(HomeserverTestCase): def default_config(self): c = super().default_config() c["worker_app"] = "synapse.app.frontend_proxy" + + c["worker_listeners"] = [ + { + "type": "http", + "port": 8080, + "bind_addresses": ["0.0.0.0"], + "resources": [{"names": ["client"]}], + } + ] + return c def test_listen_http_with_presence_enabled(self): @@ -39,14 +49,8 @@ class FrontendProxyTests(HomeserverTestCase): # Presence is on self.hs.config.use_presence = True - config = { - "port": 8080, - "bind_addresses": ["0.0.0.0"], - "resources": [{"names": ["client"]}], - } - # Listen with the config - self.hs._listen_http(config) + self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) @@ -67,14 +71,8 @@ class FrontendProxyTests(HomeserverTestCase): # Presence is off self.hs.config.use_presence = False - config = { - "port": 8080, - "bind_addresses": ["0.0.0.0"], - "resources": [{"names": ["client"]}], - } - # Listen with the config - self.hs._listen_http(config) + self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 7364f9f1ec..0f016c32eb 100644 --- a/tests/app/test_openid_listener.py +++ 
b/tests/app/test_openid_listener.py @@ -18,6 +18,7 @@ from parameterized import parameterized from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer +from synapse.config.server import parse_listener_def from tests.unittest import HomeserverTestCase @@ -35,6 +36,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. conf["worker_app"] = "yes" + return conf @parameterized.expand( @@ -53,12 +55,13 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): """ config = { "port": 8080, + "type": "http", "bind_addresses": ["0.0.0.0"], "resources": [{"names": names}], } # Listen with the config - self.hs._listen_http(config) + self.hs._listen_http(parse_listener_def(config)) # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] @@ -101,12 +104,13 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): """ config = { "port": 8080, + "type": "http", "bind_addresses": ["0.0.0.0"], "resources": [{"names": names}], } # Listen with the config - self.hs._listener_http(config, config) + self.hs._listener_http(self.hs.get_config(), parse_listener_def(config)) # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] diff --git a/tests/test_server.py b/tests/test_server.py index e9a43b1e45..adae3c6e08 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -24,6 +24,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.config.server import parse_listener_def from synapse.http.server import ( DirectServeResource, JsonResource, @@ -189,7 +190,13 @@ class OptionsResourceTests(unittest.TestCase): request.prepath = [] # This doesn't get set properly by make_request. # Create a site and query for the resource. - site = SynapseSite("test", "site_tag", {}, self.resource, "1.0") + site = SynapseSite( + "test", + "site_tag", + parse_listener_def({"type": "http", "port": 0}), + self.resource, + "1.0", + ) request.site = site resource = site.getResourceFor(request) @@ -348,7 +355,9 @@ class SiteTestCase(unittest.HomeserverTestCase): # time out the request while it's 'processing' base_resource = Resource() base_resource.putChild(b"", HangingResource()) - site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") + site = SynapseSite( + "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0" + ) server = site.buildProtocol(None) client = AccumulatingProtocol() diff --git a/tests/unittest.py b/tests/unittest.py index 6b6f224e9c..3175a3fa02 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -229,7 +229,7 @@ class HomeserverTestCase(TestCase): self.site = SynapseSite( logger_name="synapse.access.http.fake", site_tag="test", - config={}, + config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", ) diff --git a/tests/utils.py b/tests/utils.py index 59c020a051..7ba8a31ff3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -168,6 +168,7 @@ def default_config(name, parse=False): # background, which upsets the test runner. 
"update_user_directory": False, "caches": {"global_factor": 1}, + "listeners": [{"port": 0, "type": "http"}], } if parse: -- cgit 1.5.1 From a3f11567d930b7da0db068c3b313f6f4abbf12a1 Mon Sep 17 00:00:00 2001 From: Dagfinn Ilmari Mannsåker Date: Tue, 16 Jun 2020 13:51:47 +0100 Subject: Replace all remaining six usage with native Python 3 equivalents (#7704) --- changelog.d/7704.misc | 1 + contrib/graph/graph3.py | 4 +--- scripts-dev/federation_client.py | 3 +-- scripts/synapse_port_db | 4 +--- setup.cfg | 2 +- synapse/_scripts/register_new_matrix_user.py | 2 -- synapse/api/errors.py | 7 +++---- synapse/api/filtering.py | 4 +--- synapse/api/urls.py | 3 +-- synapse/appservice/__init__.py | 4 +--- synapse/appservice/api.py | 3 +-- synapse/config/_base.py | 6 ++---- synapse/config/appservice.py | 13 ++++--------- synapse/config/tls.py | 4 +--- synapse/crypto/keyring.py | 6 ++---- synapse/events/utils.py | 4 +--- synapse/events/validator.py | 12 +++++------- synapse/federation/federation_base.py | 4 +--- synapse/federation/federation_server.py | 4 +--- synapse/federation/transport/client.py | 3 +-- synapse/groups/groups_server.py | 4 +--- synapse/handlers/cas_handler.py | 3 +-- synapse/handlers/federation.py | 9 ++++----- synapse/handlers/message.py | 4 +--- synapse/handlers/profile.py | 8 +++----- synapse/handlers/room.py | 4 +--- synapse/handlers/room_member.py | 7 +++---- synapse/http/client.py | 8 +++----- synapse/http/matrixfederationclient.py | 12 +++++------- synapse/http/server.py | 4 ++-- synapse/logging/formatter.py | 3 +-- synapse/push/mailer.py | 3 +-- synapse/push/push_rule_evaluator.py | 4 +--- synapse/python_dependencies.py | 1 - synapse/replication/http/_base.py | 6 ++---- synapse/rest/admin/users.py | 20 ++++++-------------- synapse/rest/client/v1/presence.py | 4 +--- synapse/rest/client/v1/room.py | 3 +-- synapse/rest/client/v2_alpha/account.py | 5 ++--- synapse/rest/client/v2_alpha/register.py | 11 +++-------- synapse/rest/client/v2_alpha/report_event.py | 10 ++++------ synapse/rest/consent/consent_resource.py | 5 ++--- synapse/rest/media/v1/_base.py | 3 +-- synapse/rest/media/v1/media_storage.py | 6 +----- synapse/rest/media/v1/preview_url_resource.py | 9 +++------ synapse/server_notices/consent_server_notices.py | 4 +--- synapse/storage/data_stores/main/event_federation.py | 3 +-- synapse/storage/data_stores/main/events.py | 10 +++------- .../storage/data_stores/main/events_bg_updates.py | 4 +--- .../data_stores/main/schema/delta/30/as_users.py | 2 -- synapse/storage/data_stores/main/search.py | 4 +--- synapse/storage/data_stores/main/stream.py | 2 -- synapse/storage/data_stores/main/tags.py | 2 -- synapse/storage/data_stores/state/store.py | 2 -- synapse/storage/database.py | 3 +-- synapse/storage/persist_events.py | 2 -- synapse/util/async_helpers.py | 2 -- synapse/util/caches/stream_change_cache.py | 4 +--- synapse/util/file_consumer.py | 2 +- synapse/util/frozenutils.py | 6 ++---- synapse/util/wheel_timer.py | 2 -- synapse/visibility.py | 2 -- synctl | 6 ++---- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v2_alpha/test_relations.py | 9 ++++----- tests/rest/media/v1/test_media_storage.py | 2 +- tests/server.py | 4 +--- tests/state/test_v2.py | 2 -- tests/test_server.py | 3 +-- tests/test_terms_auth.py | 9 ++++----- tests/util/test_file_consumer.py | 2 +- tests/util/test_linearizer.py | 2 -- tests/utils.py | 2 +- 73 files changed, 111 insertions(+), 237 deletions(-) create mode 100644 changelog.d/7704.misc (limited to 'tests') diff --git 
a/changelog.d/7704.misc b/changelog.d/7704.misc new file mode 100644 index 0000000000..7838a613c8 --- /dev/null +++ b/changelog.d/7704.misc @@ -0,0 +1 @@ +Replace all remaining uses of `six` with native Python 3 equivalents. Contributed by @ilmari. diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index 7f9e5374a6..3154638520 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -24,8 +24,6 @@ import argparse from synapse.events import FrozenEvent from synapse.util.frozenutils import unfreeze -from six import string_types - def make_graph(file_name, room_id, file_prefix, limit): print("Reading lines") @@ -62,7 +60,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for key, value in unfreeze(event.get_dict()["content"]).items(): if value is None: value = "" - elif isinstance(value, string_types): + elif isinstance(value, str): pass else: value = json.dumps(value) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 7c19e405d4..531010185d 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -21,8 +21,7 @@ import argparse import base64 import json import sys - -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse import nacl.signing import requests diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 9a0fbc61d8..a0d81c77c2 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -23,8 +23,6 @@ import sys import time import traceback -from six import string_types - import yaml from twisted.internet import defer, reactor @@ -635,7 +633,7 @@ class Porter(object): return bool(col) if isinstance(col, bytes): return bytearray(col) - elif isinstance(col, string_types) and "\0" in col: + elif isinstance(col, str) and "\0" in col: logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, diff --git a/setup.cfg b/setup.cfg index 12a7849081..f2bca272e1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY known_first_party = synapse known_tests=tests -known_compat = mock,six +known_compat = mock known_twisted=twisted,OpenSSL multi_line_output=3 include_trailing_comma=true diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index d528450c78..55cce2db22 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -23,8 +23,6 @@ import hmac import logging import sys -from six.moves import input - import requests as _requests import yaml diff --git a/synapse/api/errors.py b/synapse/api/errors.py index a07a54580d..5305038c21 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,10 +17,9 @@ """Contains exceptions and error codes.""" import logging +from http import HTTPStatus from typing import Dict, List -from six.moves import http_client - from canonicaljson import json from twisted.web import http @@ -173,7 +172,7 @@ class ConsentNotGivenError(SynapseError): consent_url (str): The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( - code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN + code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri @@ -193,7 +192,7 @@ class UserDeactivatedError(SynapseError): msg (str): The human-readable error message """ super(UserDeactivatedError, self).__init__( - 
code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED + code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED ) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 8b64d0a285..f988f62a1e 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -17,8 +17,6 @@ # limitations under the License. from typing import List -from six import text_type - import jsonschema from canonicaljson import json from jsonschema import FormatChecker @@ -313,7 +311,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes - contains_url = isinstance(content.get("url"), text_type) + contains_url = isinstance(content.get("url"), str) labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index f34434bd67..bd03ebca5a 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -17,8 +17,7 @@ """Contains the URL paths to prefix various aspects of the server with. """ import hmac from hashlib import sha256 - -from six.moves.urllib.parse import urlencode +from urllib.parse import urlencode from synapse.config import ConfigError diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1b13e84425..0323256472 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -15,8 +15,6 @@ import logging import re -from six import string_types - from twisted.internet import defer from synapse.api.constants import EventTypes @@ -156,7 +154,7 @@ class ApplicationService(object): ) regex = regex_obj.get("regex") - if isinstance(regex, string_types): + if isinstance(regex, str): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 57174da021..da9a5e86d4 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging - -from six.moves import urllib +import urllib from prometheus_client import Counter diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 30d1050a91..1391e5fc43 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -22,8 +22,6 @@ from collections import OrderedDict from textwrap import dedent from typing import Any, MutableMapping, Optional -from six import integer_types - import yaml @@ -117,7 +115,7 @@ class Config(object): @staticmethod def parse_size(value): - if isinstance(value, integer_types): + if isinstance(value, int): return value sizes = {"K": 1024, "M": 1024 * 1024} size = 1 @@ -129,7 +127,7 @@ class Config(object): @staticmethod def parse_duration(value): - if isinstance(value, integer_types): + if isinstance(value, int): return value second = 1000 minute = 60 * second diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index ca43e96bd1..8ed3e24258 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -14,9 +14,7 @@ import logging from typing import Dict - -from six import string_types -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse import yaml from netaddr import IPSet @@ -98,17 +96,14 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: - if not isinstance(as_info.get(field), string_types): + if not isinstance(as_info.get(field), str): raise KeyError( "Required string field: '%s' (%s)" % (field, config_filename) ) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if ( - not isinstance(as_info.get("url"), string_types) - and as_info.get("url", "") is not None - ): + if not isinstance(as_info.get("url"), str) and as_info.get("url", "") is not None: raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) @@ -138,7 +133,7 @@ def _load_appservice(hostname, as_info, config_filename): ns, regex_obj, ) - if not isinstance(regex_obj.get("regex"), string_types): + if not isinstance(regex_obj.get("regex"), str): raise ValueError("Missing/bad type 'regex' key in %s", regex_obj) if not isinstance(regex_obj.get("exclusive"), bool): raise ValueError( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index a65538562b..e368ea564d 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -20,8 +20,6 @@ from datetime import datetime from hashlib import sha256 from typing import List -import six - from unpaddedbase64 import encode_base64 from OpenSSL import SSL, crypto @@ -59,7 +57,7 @@ class TlsConfig(Config): logger.warning(ACME_SUPPORT_ENABLED_WARN) # hyperlink complains on py2 if this is not a Unicode - self.acme_url = six.text_type( + self.acme_url = str( acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") ) self.acme_port = acme_config.get("port", 80) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index a9f4025bfe..dbfc3e8972 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -15,11 +15,9 @@ # limitations under the License. 
import logging +import urllib from collections import defaultdict -import six -from six.moves import urllib - import attr from signedjson.key import ( decode_verify_key_bytes, @@ -661,7 +659,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") - if not isinstance(server_name, six.string_types): + if not isinstance(server_name, str): raise KeyLookupError( "Malformed response from key notary server %s: invalid server_name" % (perspective_name,) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index dd340be9a7..f6b507977f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,8 +16,6 @@ import collections import re from typing import Any, Mapping, Union -from six import string_types - from frozendict import frozendict from twisted.internet import defer @@ -318,7 +316,7 @@ def serialize_event( if only_event_fields: if not isinstance(only_event_fields, list) or not all( - isinstance(f, string_types) for f in only_event_fields + isinstance(f, str) for f in only_event_fields ): raise TypeError("only_event_fields must be a list of strings") d = only_fields(d, only_event_fields) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index b001c64bb4..588d222f36 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import integer_types, string_types - from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions @@ -53,7 +51,7 @@ class EventValidator(object): event_strings = ["origin"] for s in event_strings: - if not isinstance(getattr(event, s), string_types): + if not isinstance(getattr(event, s), str): raise SynapseError(400, "'%s' not a string type" % (s,)) # Depending on the room version, ensure the data is spec compliant JSON. 
@@ -90,7 +88,7 @@ class EventValidator(object): max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if not isinstance(min_lifetime, integer_types): + if not isinstance(min_lifetime, int): raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -124,7 +122,7 @@ class EventValidator(object): ) if max_lifetime is not None: - if not isinstance(max_lifetime, integer_types): + if not isinstance(max_lifetime, int): raise SynapseError( code=400, msg="'max_lifetime' must be an integer", @@ -183,7 +181,7 @@ class EventValidator(object): strings.append("state_key") for s in strings: - if not isinstance(getattr(event, s), string_types): + if not isinstance(getattr(event, s), str): raise SynapseError(400, "Not '%s' a string type" % (s,)) RoomID.from_string(event.room_id) @@ -223,7 +221,7 @@ class EventValidator(object): for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) - if not isinstance(d[s], string_types): + if not isinstance(d[s], str): raise SynapseError(400, "'%s' not a string type" % (s,)) def _ensure_state_event(self, event): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index b2ab5bd6a4..420df2385f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Iterable, List -import six - from twisted.internet import defer from twisted.internet.defer import Deferred, DeferredList from twisted.python.failure import Failure @@ -294,7 +292,7 @@ def event_from_pdu_json( assert_params_in_dict(pdu_json, ("type", "depth")) depth = pdu_json["depth"] - if not isinstance(depth, six.integer_types): + if not isinstance(depth, int): raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6920c23723..afe0a8238b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -17,8 +17,6 @@ import logging from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union -import six - from canonicaljson import json from prometheus_client import Counter @@ -751,7 +749,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: - if not isinstance(acl_entry, six.string_types): + if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 060bf07197..9f99311419 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,10 +15,9 @@ # limitations under the License. 
import logging +import urllib from typing import Any, Dict, Optional -from six.moves import urllib - from twisted.internet import defer from synapse.api.constants import Membership diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8a9de913b3..8db8ab1b7b 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -17,8 +17,6 @@ import logging -from six import string_types - from synapse.api.errors import Codes, SynapseError from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute @@ -513,7 +511,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): for keyname in ("name", "avatar_url", "short_description", "long_description"): if keyname in content: value = content[keyname] - if not isinstance(value, string_types): + if not isinstance(value, str): raise SynapseError(400, "%r value is not a string" % (keyname,)) profile[keyname] = value diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 64aaa1335c..76f213723a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,11 +14,10 @@ # limitations under the License. import logging +import urllib import xml.etree.ElementTree as ET from typing import Dict, Optional, Tuple -from six.moves import urllib - from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d6038d9995..873f6bc39f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,10 +19,9 @@ import itertools import logging +from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Sequence, Tuple -from six.moves import http_client, zip - import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json @@ -1194,7 +1193,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.prev_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: logger.warning( @@ -1202,7 +1201,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.auth_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -1545,7 +1544,7 @@ class FederationHandler(BaseHandler): # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") # keep a record of the room version, if we don't yet know it. 
# (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 354da9a3b5..200127d291 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -17,8 +17,6 @@ import logging from typing import Optional, Tuple -from six import string_types - from canonicaljson import encode_canonical_json, json from twisted.internet import defer @@ -715,7 +713,7 @@ class EventCreationHandler(object): spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: - if not isinstance(spam_error, string_types): + if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 302efc1b9a..4b1e3073a8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -15,8 +15,6 @@ import logging -from six import raise_from - from twisted.internet import defer from synapse.api.errors import ( @@ -84,7 +82,7 @@ class BaseProfileHandler(BaseHandler): ) return result except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -135,7 +133,7 @@ class BaseProfileHandler(BaseHandler): ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -212,7 +210,7 @@ class BaseProfileHandler(BaseHandler): ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f7401373ca..950a84acd0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -24,8 +24,6 @@ import string from collections import OrderedDict from typing import Tuple -from six import string_types - from synapse.api.constants import ( EventTypes, JoinRules, @@ -595,7 +593,7 @@ class RoomCreationHandler(BaseHandler): "room_version", self.config.default_room_version.identifier ) - if not isinstance(room_version_id, string_types): + if not isinstance(room_version_id, str): raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0f7af982f0..27c479da9e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,10 +17,9 @@ import abc import logging +from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Tuple -from six.moves import http_client - from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError @@ -361,7 +360,7 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -444,7 +443,7 @@ class 
RoomMemberHandler(object): is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( - http_client.FORBIDDEN, + HTTPStatus.FORBIDDEN, "You cannot reject this invite", errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM, ) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3cef747a4d..8743e9839d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,11 +15,9 @@ # limitations under the License. import logging +import urllib from io import BytesIO -from six import raise_from, text_type -from six.moves import urllib - import treq from canonicaljson import encode_canonical_json, json from netaddr import IPAddress @@ -577,7 +575,7 @@ class SimpleHttpClient(object): # This can happen e.g. because the body is too large. raise except Exception as e: - raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) + raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e return ( length, @@ -638,7 +636,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): - if isinstance(arg, text_type): + if isinstance(arg, str): return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2d47b9ea00..7b33b9f10a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,11 +17,9 @@ import cgi import logging import random import sys +import urllib from io import BytesIO -from six import raise_from, string_types -from six.moves import urllib - import attr import treq from canonicaljson import encode_canonical_json @@ -432,10 +430,10 @@ class MatrixFederationHttpClient(object): except TimeoutError as e: raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: - raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) + raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: logger.info("Failed to send request: %s", e) - raise_from(RequestSendFailed(e, can_retry=True), e) + raise RequestSendFailed(e, can_retry=True) from e incoming_responses_counter.labels( request.method, response.code @@ -487,7 +485,7 @@ class MatrixFederationHttpClient(object): # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException if response.code == 429: - raise_from(RequestSendFailed(e, can_retry=True), e) + raise RequestSendFailed(e, can_retry=True) from e else: raise e @@ -998,7 +996,7 @@ def encode_query_args(args): encoded_args = {} for k, vs in args.items(): - if isinstance(vs, string_types): + if isinstance(vs, str): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] diff --git a/synapse/http/server.py b/synapse/http/server.py index 2487a72171..6aa1dc1f92 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,10 +16,10 @@ import collections import html -import http.client import logging import types import urllib +from http import HTTPStatus from io import BytesIO from typing import Awaitable, Callable, TypeVar, Union @@ -188,7 +188,7 @@ def return_html_error( exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http.HTTPStatus.INTERNAL_SERVER_ERROR + code = HTTPStatus.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index fbf570c756..d736ad5b9b 100644 --- a/synapse/logging/formatter.py +++ 
b/synapse/logging/formatter.py @@ -16,8 +16,7 @@ import logging import traceback - -from six import StringIO +from io import StringIO class LogFormatter(logging.Formatter): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index d57a66a697..dda560b2c2 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -17,12 +17,11 @@ import email.mime.multipart import email.utils import logging import time +import urllib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar -from six.moves import urllib - import bleach import jinja2 diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index aeac257a6e..8e0d3a416d 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -18,8 +18,6 @@ import logging import re from typing import Pattern -from six import string_types - from synapse.events import EventBase from synapse.types import UserID from synapse.util.caches import register_cache @@ -244,7 +242,7 @@ def _flatten_dict(d, prefix=[], result=None): if result is None: result = {} for key, value in d.items(): - if isinstance(value, string_types): + if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() elif hasattr(value, "items"): _flatten_dict(value, prefix=(prefix + [key]), result=result) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8ec1a619a2..d655aba35c 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -66,7 +66,6 @@ REQUIREMENTS = [ "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", - "six>=1.10", "prometheus_client>=0.0.18,<0.8.0", # we use attr.validators.deep_iterable, which arrived in 19.1.0 "attrs>=19.1.0", diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 793cef6c26..9caf1e80c1 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,12 +16,10 @@ import abc import logging import re +import urllib from inspect import signature from typing import Dict, List, Tuple -from six import raise_from -from six.moves import urllib - from twisted.internet import defer from synapse.api.errors import ( @@ -220,7 +218,7 @@ class ReplicationEndpoint(object): # importantly, not stack traces everywhere) raise e.to_synapse_error() except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to talk to master"), e) + raise SynapseError(502, "Failed to talk to master") from e return result diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index fefc8f71fa..e4330c39d6 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -16,9 +16,7 @@ import hashlib import hmac import logging import re - -from six import text_type -from six.moves import http_client +from http import HTTPStatus from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -215,10 +213,7 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: - if ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): + if not isinstance(body["password"], str) or len(body["password"]) > 512: raise SynapseError(400, "Invalid password") else: new_password = body["password"] @@ -252,7 +247,7 @@ class UserRestServletV2(RestServlet): password = body.get("password") password_hash = None if password is not None: - if not 
isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -370,10 +365,7 @@ class UserRegisterServlet(RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON ) else: - if ( - not isinstance(body["username"], text_type) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -386,7 +378,7 @@ class UserRegisterServlet(RestServlet): ) else: password = body["password"] - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_bytes = password.encode("utf-8") @@ -477,7 +469,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 7cf007d35e..970fdd5834 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,8 +17,6 @@ """ import logging -from six import string_types - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -73,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet): if "status_msg" in content: state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], string_types): + if not isinstance(state["status_msg"], str): raise SynapseError(400, "status_msg must be a string.") if content: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 105e0cf4d2..46811abbfa 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -18,8 +18,7 @@ import logging import re from typing import List, Optional - -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse from canonicaljson import json diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1dc4a3247f..923bcb9f85 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,8 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging - -from six.moves import http_client +from http import HTTPStatus from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError @@ -321,7 +320,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b9ffe86b2a..141a3f5fac 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -18,8 +18,6 @@ import hmac import logging from typing import List, Union -from six import string_types - import synapse import synapse.api.auth import synapse.types @@ -413,7 +411,7 @@ class RegisterRestServlet(RestServlet): # in sessions. Pull out the username/password provided to us. if "password" in body: password = body.pop("password") - if not isinstance(password, string_types) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(password) @@ -425,10 +423,7 @@ class RegisterRestServlet(RestServlet): desired_username = None if "username" in body: - if ( - not isinstance(body["username"], string_types) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") desired_username = body["username"] @@ -453,7 +448,7 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if isinstance(desired_username, string_types): + if isinstance(desired_username, str): result = await self._do_appservice_registration( desired_username, access_token, body ) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index f067b5edac..e15927c4ea 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -14,9 +14,7 @@ # limitations under the License. 
import logging - -from six import string_types -from six.moves import http_client +from http import HTTPStatus from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( @@ -47,15 +45,15 @@ class ReportEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ("reason", "score")) - if not isinstance(body["reason"], string_types): + if not isinstance(body["reason"], str): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'reason' must be a string", Codes.BAD_JSON, ) if not isinstance(body["score"], int): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", Codes.BAD_JSON, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 1ddf9997ff..049c16b236 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -16,10 +16,9 @@ import hmac import logging from hashlib import sha256 +from http import HTTPStatus from os import path -from six.moves import http_client - import jinja2 from jinja2 import TemplateNotFound @@ -223,4 +222,4 @@ class ConsentResource(DirectServeResource): ) if not compare_digest(want_mac, userhmac): - raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 3689777266..595849f9d5 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -16,8 +16,7 @@ import logging import os - -from six.moves import urllib +import urllib from twisted.internet import defer from twisted.protocols.basic import FileSender diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 683a79c966..79cb0dddbe 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -17,9 +17,6 @@ import contextlib import logging import os import shutil -import sys - -import six from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -117,12 +114,11 @@ class MediaStorage(object): with open(fname, "wb") as f: yield f, fname, finish except Exception: - t, v, tb = sys.exc_info() try: os.remove(fname) except Exception: pass - six.reraise(t, v, tb) + raise if not finished_called: raise Exception("Finished callback not called") diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f206605727..f67e0fb3ec 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -24,10 +24,7 @@ import shutil import sys import traceback from typing import Dict, Optional - -import six -from six import string_types -from six.moves import urllib_parse as urlparse +from urllib import parse as urlparse from canonicaljson import json @@ -188,7 +185,7 @@ class PreviewUrlResource(DirectServeResource): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. 
og = cache_result["og"] - if isinstance(og, six.text_type): + if isinstance(og, str): og = og.encode("utf8") return og @@ -631,7 +628,7 @@ def _iterate_over_text(tree, *tags_to_ignore): if el is None: return - if isinstance(el, string_types): + if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e7e8b8e688..3bfc8d7278 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from six import string_types - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -118,7 +116,7 @@ def copy_with_str_subst(x, substitutions): Returns: copy of x """ - if isinstance(x, string_types): + if isinstance(x, str): return x % substitutions if isinstance(x, dict): return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()} diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 24ce8c4330..a6bb3221ff 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,10 +14,9 @@ # limitations under the License. import itertools import logging +from queue import Empty, PriorityQueue from typing import Dict, List, Optional, Set, Tuple -from six.moves.queue import Empty, PriorityQueue - from twisted.internet import defer from synapse.api.errors import StoreError diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 8a13101f1d..cfd24d2f06 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -21,9 +21,6 @@ from collections import OrderedDict, namedtuple from functools import wraps from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple -from six import integer_types, text_type -from six.moves import range - import attr from canonicaljson import json from prometheus_client import Counter @@ -893,8 +890,7 @@ class PersistEventsStore: "received_ts": self._clock.time_msec(), "sender": event.sender, "contains_url": ( - "url" in event.content - and isinstance(event.content["url"], text_type) + "url" in event.content and isinstance(event.content["url"], str) ), } for event, _ in events_and_contexts @@ -1345,10 +1341,10 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), integer_types) + and not isinstance(event.content.get("min_lifetime"), int) ) or ( "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), integer_types) + and not isinstance(event.content.get("max_lifetime"), int) ): # Ignore the event if one of the value isn't an integer. 
return diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index f54c8b1ee0..62d28f44dc 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -15,8 +15,6 @@ import logging -from six import text_type - from canonicaljson import json from twisted.internet import defer @@ -133,7 +131,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], text_type) + contains_url &= isinstance(content["url"], str) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py index 9b95411fb6..b42c02710a 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py +++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py @@ -13,8 +13,6 @@ # limitations under the License. import logging -from six.moves import range - from synapse.config.appservice import load_appservices logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 13f49d8060..a8381dc577 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -17,8 +17,6 @@ import logging import re from collections import namedtuple -from six import string_types - from canonicaljson import json from twisted.internet import defer @@ -180,7 +178,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # skip over it. continue - if not isinstance(value, string_types): + if not isinstance(value, str): # If the event body, name or topic isn't a string # then skip over it continue diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index e89f0bffb5..379d758b5d 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -40,8 +40,6 @@ import abc import logging from collections import namedtuple -from six.moves import range - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 4219018302..f8c776be3f 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -16,8 +16,6 @@ import logging -from six.moves import range - from canonicaljson import json from twisted.internet import defer diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index b720212e55..5db9f20135 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from six.moves import range - from twisted.internet import defer from synapse.api.constants import EventTypes diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 645a70934c..3be20c866a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -16,6 +16,7 @@ # limitations under the License. 
import logging import time +from sys import intern from time import monotonic as monotonic_time from typing import ( Any, @@ -29,8 +30,6 @@ from typing import ( TypeVar, ) -from six.moves import intern, range - from prometheus_client import Histogram from twisted.enterprise import adbapi diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 92dfd709bc..ec894a91cb 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -20,8 +20,6 @@ import logging from collections import deque, namedtuple from typing import Iterable, List, Optional, Set, Tuple -from six.moves import range - from prometheus_client import Counter, Histogram from twisted.internet import defer diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f7af2bca7f..df42486351 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -19,8 +19,6 @@ import logging from contextlib import contextmanager from typing import Dict, Sequence, Set, Union -from six.moves import range - import attr from twisted.internet import defer diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 2a161bf244..c541bf4579 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -17,8 +17,6 @@ import logging import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union -from six import integer_types - from sortedcontainers import SortedDict from synapse.types import Collection @@ -88,7 +86,7 @@ class StreamChangeCache: def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) in integer_types + assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 8b17d1c8b8..6a3f6177b1 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import queue +import queue from twisted.internet import threads diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 9815bb8667..eab78dd256 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import binary_type, text_type - from canonicaljson import json from frozendict import frozendict @@ -26,7 +24,7 @@ def freeze(o): if isinstance(o, frozendict): return o - if isinstance(o, (binary_type, text_type)): + if isinstance(o, (bytes, str)): return o try: @@ -41,7 +39,7 @@ def unfreeze(o): if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, (binary_type, text_type)): + if isinstance(o, (bytes, str)): return o try: diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 9bf6a44f75..023beb5ede 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from six.moves import range - class _Entry(object): __slots__ = ["end_key", "queue"] diff --git a/synapse/visibility.py b/synapse/visibility.py index 780927cda1..3dfd4af26c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,8 +16,6 @@ import logging import operator -from six.moves import map - from twisted.internet import defer from synapse.api.constants import EventTypes, Membership diff --git a/synctl b/synctl index 960fd357ee..ca398b84bd 100755 --- a/synctl +++ b/synctl @@ -26,8 +26,6 @@ import subprocess import sys import time -from six import iteritems - import yaml from synapse.config import find_config_files @@ -251,7 +249,7 @@ def main(): os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in iteritems(cache_factors): + for cache_name, factor in cache_factors.items(): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) worker_configfiles = [] @@ -362,7 +360,7 @@ def main(): if worker.cache_factor: os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - for cache_name, factor in iteritems(worker.cache_factors): + for cache_name, factor in worker.cache_factors.items(): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) if not start_worker(worker.app, configfile, worker.configfile): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4886bbb401..5ccda8b2bd 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,9 +19,9 @@ """Tests REST events for /rooms paths.""" import json +from urllib import parse as urlparse from mock import Mock -from six.moves.urllib import parse as urlparse from twisted.internet import defer diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index c7e5859970..fd641a7c2f 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -15,8 +15,7 @@ import itertools import json - -import six +import urllib from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Make sure next_batch has something in it that looks like it could be a # valid token. 
self.assertIsInstance( - channel.json_body.get("next_batch"), six.string_types, channel.json_body + channel.json_body.get("next_batch"), str, channel.json_body ) def test_repeated_paginate_relations(self): @@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) for _ in range(20): from_token = "" if prev_token: @@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): query = "" if key: - query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) + query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) original_id = parent_id if parent_id else self.parent_id diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1ca648ef2b..aefe648bdb 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -20,9 +20,9 @@ import tempfile from binascii import unhexlify from io import BytesIO from typing import Optional +from urllib import parse from mock import Mock -from six.moves.urllib import parse import attr import PIL.Image as Image diff --git a/tests/server.py b/tests/server.py index 1644710aa0..a5e57c52fa 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,8 +2,6 @@ import json import logging from io import BytesIO -from six import text_type - import attr from zope.interface import implementer @@ -174,7 +172,7 @@ def make_request( if not path.startswith(b"/"): path = b"/" + path - if isinstance(content, text_type): + if isinstance(content, str): content = content.encode("utf8") site = FakeSite() diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index a44960203e..cdc347bc53 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -15,8 +15,6 @@ import itertools -from six.moves import zip - import attr from synapse.api.constants import EventTypes, JoinRules, Membership diff --git a/tests/test_server.py b/tests/test_server.py index adae3c6e08..3f6f468e5b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,8 +14,7 @@ import logging import re - -from six import StringIO +from io import StringIO from twisted.internet.defer import Deferred from twisted.python.failure import Failure diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5c2817cf28..b89798336c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -14,7 +14,6 @@ import json -import six from mock import Mock from twisted.test.proto_helpers import MemoryReactorClock @@ -60,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) - self.assertIsInstance(channel.json_body["session"], six.text_type) + self.assertIsInstance(channel.json_body["session"], str) self.assertIsInstance(channel.json_body["flows"], list) for flow in channel.json_body["flows"]: @@ -125,6 +124,6 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) - self.assertIsInstance(channel.json_body["user_id"], six.text_type) - self.assertIsInstance(channel.json_body["access_token"], six.text_type) - self.assertIsInstance(channel.json_body["device_id"], six.text_type) + self.assertIsInstance(channel.json_body["user_id"], str) + 
self.assertIsInstance(channel.json_body["access_token"], str) + self.assertIsInstance(channel.json_body["device_id"], str) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index e90e08d1c0..8d6627ec33 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -15,9 +15,9 @@ import threading +from io import StringIO from mock import NonCallableMock -from six import StringIO from twisted.internet import defer, reactor diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index ca3858b184..0e52811948 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import range - from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError diff --git a/tests/utils.py b/tests/utils.py index 7ba8a31ff3..4d17355a5c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,9 +21,9 @@ import time import uuid import warnings from inspect import getcallargs +from urllib import parse as urlparse from mock import Mock, patch -from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor -- cgit 1.5.1 From ac51bd581aa98b8972d785a898d6233def9b636a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 Jun 2020 10:43:29 -0400 Subject: Include a user agent in federation requests. (#7677) --- changelog.d/7677.bugfix | 1 + synapse/http/federation/matrix_federation_agent.py | 10 +++++++++- synapse/http/federation/well_known_resolver.py | 17 +++++++++++++++-- synapse/http/matrixfederationclient.py | 9 ++++++++- tests/http/federation/test_matrix_federation_agent.py | 10 ++++++++++ 5 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7677.bugfix (limited to 'tests') diff --git a/changelog.d/7677.bugfix b/changelog.d/7677.bugfix new file mode 100644 index 0000000000..b63f041096 --- /dev/null +++ b/changelog.d/7677.bugfix @@ -0,0 +1 @@ +Include a user-agent for federation and well-known requests. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index f5f917f5ae..c5fc746f2f 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -48,6 +48,9 @@ class MatrixFederationAgent(object): tls_client_options_factory (FederationPolicyForHTTPS|None): factory to use for fetching client tls options, or none to disable TLS. + user_agent (bytes): + The user agent header to use for federation requests. + _srv_resolver (SrvResolver|None): SRVResolver impl to use for looking up SRV records. None to use a default implementation. @@ -61,6 +64,7 @@ class MatrixFederationAgent(object): self, reactor, tls_client_options_factory, + user_agent, _srv_resolver=None, _well_known_resolver=None, ): @@ -78,6 +82,7 @@ class MatrixFederationAgent(object): ), pool=self._pool, ) + self.user_agent = user_agent if _well_known_resolver is None: _well_known_resolver = WellKnownResolver( @@ -87,6 +92,7 @@ class MatrixFederationAgent(object): pool=self._pool, contextFactory=tls_client_options_factory, ), + user_agent=self.user_agent, ) self._well_known_resolver = _well_known_resolver @@ -149,7 +155,7 @@ class MatrixFederationAgent(object): parsed_uri = urllib.parse.urlparse(uri) # We need to make sure the host header is set to the netloc of the - # server. + # server and that a user-agent is provided. 
if headers is None: headers = Headers() else: @@ -157,6 +163,8 @@ class MatrixFederationAgent(object): if not headers.hasHeader(b"host"): headers.addRawHeader(b"host", parsed_uri.netloc) + if not headers.hasHeader(b"user-agent"): + headers.addRawHeader(b"user-agent", self.user_agent) res = yield make_deferred_yieldable( self._agent.request(method, uri, headers, bodyProducer) diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 7ddfad286d..89a3b041ce 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -23,6 +23,7 @@ import attr from twisted.internet import defer from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime +from twisted.web.http_headers import Headers from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock @@ -78,7 +79,12 @@ class WellKnownResolver(object): """ def __init__( - self, reactor, agent, well_known_cache=None, had_well_known_cache=None + self, + reactor, + agent, + user_agent, + well_known_cache=None, + had_well_known_cache=None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -92,6 +98,7 @@ class WellKnownResolver(object): self._well_known_cache = well_known_cache self._had_valid_well_known_cache = had_well_known_cache self._well_known_agent = RedirectAgent(agent) + self.user_agent = user_agent @defer.inlineCallbacks def get_well_known(self, server_name): @@ -227,6 +234,10 @@ class WellKnownResolver(object): uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") + headers = { + b"User-Agent": [self.user_agent], + } + i = 0 while True: i += 1 @@ -234,7 +245,9 @@ class WellKnownResolver(object): logger.info("Fetching %s", uri_str) try: response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri) + self._well_known_agent.request( + b"GET", uri, headers=Headers(headers) + ) ) body = yield make_deferred_yieldable(readBody(response)) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 7b33b9f10a..18f6a8fd29 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -197,7 +197,14 @@ class MatrixFederationHttpClient(object): self.reactor = Reactor() - self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory) + user_agent = hs.version_string + if hs.config.user_agent_suffix: + user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) + user_agent = user_agent.encode("ascii") + + self.agent = MatrixFederationAgent( + self.reactor, tls_client_options_factory, user_agent + ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 562397cdda..954e059e76 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -86,6 +86,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), + b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ) @@ -93,6 +94,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, 
tls_client_options_factory=self.tls_factory, + user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -186,6 +188,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) request = well_known_server.requests[0] + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) self._send_well_known_response(request, content, headers=response_headers) return well_known_server @@ -231,6 +236,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) content = request.content.read() self.assertEqual(content, b"") @@ -719,10 +727,12 @@ class MatrixFederationAgentTests(unittest.TestCase): agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, + user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=tls_factory), + b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ), -- cgit 1.5.1 From 434716e1d33ec7ba772177f6659263539d68603f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 17 Jun 2020 08:36:46 -0400 Subject: Fetch from the r0 media path instead of the unspecced v1. (#7714) --- changelog.d/7714.bugfix | 1 + synapse/rest/media/v1/media_repository.py | 8 ++++---- tests/rest/media/v1/test_media_storage.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/7714.bugfix (limited to 'tests') diff --git a/changelog.d/7714.bugfix b/changelog.d/7714.bugfix new file mode 100644 index 0000000000..78925d94d1 --- /dev/null +++ b/changelog.d/7714.bugfix @@ -0,0 +1 @@ +Synapse will now fetch media from the proper specified URL (using the r0 prefix instead of the unspecified v1). 
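The change below swaps the unspecced /_matrix/media/v1/ prefix for the spec's r0 prefix when fetching remote media. As a rough standalone sketch (not Synapse code; the server name and media id are made up), the download path ends up being assembled like this:

    def build_download_path(server_name: str, media_id: str) -> str:
        # "r0" is the specced media prefix; "v1" was never part of the spec.
        return "/".join(("/_matrix/media/r0/download", server_name, media_id))

    # build_download_path("example.com", "abcdefghij")
    # -> "/_matrix/media/r0/download/example.com/abcdefghij"
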
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 4ee8c60257..45628c07b4 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -338,7 +338,7 @@ class MediaRepository(object): with self.media_storage.store_into_file(file_info) as (f, fname, finish): request_path = "/".join( - ("/_matrix/media/v1/download", server_name, media_id) + ("/_matrix/media/r0/download", server_name, media_id) ) try: length, headers = await self.client.get_file( @@ -703,7 +703,7 @@ class MediaRepositoryResource(Resource): Uploads are POSTed to a resource which returns a token which is used to GET the download:: - => POST /_matrix/media/v1/upload HTTP/1.1 + => POST /_matrix/media/r0/upload HTTP/1.1 Content-Type: Content-Length: @@ -714,7 +714,7 @@ class MediaRepositoryResource(Resource): { "content_uri": "mxc:///" } - => GET /_matrix/media/v1/download// HTTP/1.1 + => GET /_matrix/media/r0/download// HTTP/1.1 <= HTTP/1.1 200 OK Content-Type: @@ -725,7 +725,7 @@ class MediaRepositoryResource(Resource): Clients can get thumbnails by supplying a desired width and height and thumbnailing method:: - => GET /_matrix/media/v1/thumbnail/ + => GET /_matrix/media/r0/thumbnail/ /?width=&height=&method= HTTP/1.1 <= HTTP/1.1 200 OK diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index aefe648bdb..2ed9312d56 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -232,7 +232,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(len(self.fetches), 1) self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual( - self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id + self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id ) self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) -- cgit 1.5.1 From 3630825612054f04ae9d625583d26db0a78fd3eb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 17 Jun 2020 10:37:59 -0400 Subject: Convert the typing handler to async/await. (#7679) --- changelog.d/7679.misc | 1 + synapse/handlers/typing.py | 29 +++++++++++------------------ tests/handlers/test_typing.py | 13 +++++++------ 3 files changed, 19 insertions(+), 24 deletions(-) create mode 100644 changelog.d/7679.misc (limited to 'tests') diff --git a/changelog.d/7679.misc b/changelog.d/7679.misc new file mode 100644 index 0000000000..7db94691a9 --- /dev/null +++ b/changelog.d/7679.misc @@ -0,0 +1 @@ +Convert typing handler to async/await. 
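The diff that follows converts the typing handler from @defer.inlineCallbacks generators to native coroutines. A generic before/after sketch of the pattern (the store object and its lookup() method are placeholders, not Synapse APIs):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def old_style(store):
        result = yield store.lookup()   # yields a Deferred
        return result

    async def new_style(store):
        result = await store.lookup()   # Deferreds are awaitable, so await works here too
        return result

Callers that still need a Deferred can wrap the coroutine with defer.ensureDeferred(new_style(store)), which is also why the test changes further down swap self.successResultOf(...) for self.get_success(...).
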
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 4330abb9f7..6c7abaa578 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import List, Tuple -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.logging.context import run_in_background from synapse.types import UserID, get_domain_from_id @@ -115,8 +113,7 @@ class TypingHandler(object): def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) - @defer.inlineCallbacks - def started_typing(self, target_user, auth_user, room_id, timeout): + async def started_typing(self, target_user, auth_user, room_id, timeout): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -126,7 +123,7 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -145,8 +142,7 @@ class TypingHandler(object): self._push_update(member=member, typing=True) - @defer.inlineCallbacks - def stopped_typing(self, target_user, auth_user, room_id): + async def stopped_typing(self, target_user, auth_user, room_id): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -156,7 +152,7 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) @@ -164,12 +160,11 @@ class TypingHandler(object): self._stopped_typing(member) - @defer.inlineCallbacks def user_left_room(self, user, room_id): user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) - yield self._stopped_typing(member) + self._stopped_typing(member) def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): @@ -188,10 +183,9 @@ class TypingHandler(object): self._push_update_local(member=member, typing=typing) - @defer.inlineCallbacks - def _push_remote(self, member, typing): + async def _push_remote(self, member, typing): try: - users = yield self.state.get_current_users_in_room(member.room_id) + users = await self.state.get_current_users_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() now = self.clock.time_msec() @@ -215,8 +209,7 @@ class TypingHandler(object): except Exception: logger.exception("Error pushing typing notif to remotes") - @defer.inlineCallbacks - def _recv_edu(self, origin, content): + async def _recv_edu(self, origin, content): room_id = content["room_id"] user_id = content["user_id"] @@ -231,7 +224,7 @@ class TypingHandler(object): ) return - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: @@ -330,7 +323,7 @@ class TypingNotificationEventSource(object): "content": {"user_ids": list(typing)}, } - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): with 
Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -344,7 +337,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return defer.succeed((events, handler._latest_room_serial)) + return (events, handler._latest_room_serial) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2fa8d4739b..1e6a53bf7f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,6 +129,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + return defer.succeed(None) hs.get_auth().check_user_in_room = check_user_in_room @@ -138,7 +139,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return {str(u) for u in self.room_members} + return defer.succeed({str(u) for u in self.room_members}) hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -163,7 +164,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf( + self.get_success( self.handler.started_typing( target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 ) @@ -190,7 +191,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] - self.successResultOf( + self.get_success( self.handler.started_typing( target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 ) @@ -265,7 +266,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf( + self.get_success( self.handler.stopped_typing( target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID ) @@ -305,7 +306,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf( + self.get_success( self.handler.started_typing( target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 ) @@ -344,7 +345,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # SYN-230 - see if we can still set after timeout - self.successResultOf( + self.get_success( self.handler.started_typing( target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 ) -- cgit 1.5.1 From 95e41f368b19996872a1661d7066670fe65f1eba Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 22 Jun 2020 08:04:14 -0400 Subject: Allow local media to be marked as safe from being quarantined. 
(#7718) --- changelog.d/7718.feature | 1 + scripts/synapse_port_db | 1 + .../storage/data_stores/main/media_repository.py | 9 ++ synapse/storage/data_stores/main/room.py | 42 ++----- .../58/08_media_safe_from_quarantine.sql.postgres | 18 +++ .../58/08_media_safe_from_quarantine.sql.sqlite | 18 +++ tests/rest/admin/test_admin.py | 137 ++++++++++----------- 7 files changed, 119 insertions(+), 107 deletions(-) create mode 100644 changelog.d/7718.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres create mode 100644 synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite (limited to 'tests') diff --git a/changelog.d/7718.feature b/changelog.d/7718.feature new file mode 100644 index 0000000000..17071b9ea9 --- /dev/null +++ b/changelog.d/7718.feature @@ -0,0 +1 @@ +Media can now be marked as safe from quarantined. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 810e08beb5..c2023f3e4d 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = { "account_validity": ["email_sent"], "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], + "local_media_repository": ["safe_from_quarantine"], } diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 8aecd414c2..15bc13cbd0 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) + def mark_local_media_as_safe(self, media_id: str): + """Mark a local media as safe from quarantining.""" + return self.db.simple_update_one( + table="local_media_repository", + keyvalues={"media_id": media_id}, + updatevalues={"safe_from_quarantine": True}, + desc="mark_local_media_as_safe", + ) + def get_url_cache(self, url, ts): """Get the media_id and ts for a cached URL as of the given timestamp Returns: diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 46f643c6b9..13e366536a 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -626,36 +626,10 @@ class RoomWorkerStore(SQLBaseStore): def _quarantine_media_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - total_media_quarantined = 0 - - # Now update all the tables to set the quarantined_by flag - - txn.executemany( - """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """, - ((quarantined_by, media_id) for media_id in local_mxcs), - ) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_mxcs - ), + return self._quarantine_media_txn( + txn, local_mxcs, remote_mxcs, quarantined_by ) - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) - - return total_media_quarantined - return self.db.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -805,17 +779,17 @@ class RoomWorkerStore(SQLBaseStore): Returns: The total number of media items quarantined """ - total_media_quarantined = 0 - # Update all the tables to set the quarantined_by flag txn.executemany( """ UPDATE local_media_repository SET quarantined_by = ? 
- WHERE media_id = ? + WHERE media_id = ? AND safe_from_quarantine = ? """, - ((quarantined_by, media_id) for media_id in local_mxcs), + ((quarantined_by, media_id, False) for media_id in local_mxcs), ) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 txn.executemany( """ @@ -825,9 +799,7 @@ class RoomWorkerStore(SQLBaseStore): """, ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), ) - - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 return total_media_quarantined diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres new file mode 100644 index 0000000000..597f2ffd3d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The local_media_repository should have files which do not get quarantined, +-- e.g. files from sticker packs. +ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite new file mode 100644 index 0000000000..69db89ac0e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The local_media_repository should have files which do not get quarantined, +-- e.g. files from sticker packs. 
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0; diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 977615ebef..b1a4decced 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): return hs + def _ensure_quarantined(self, admin_user_tok, server_and_media_id): + """Ensure a piece of media is quarantined when trying to access it.""" + request, channel = self.make_request( + "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id + ), + ) + def test_quarantine_media_requires_admin(self): self.register_user("nonadmin", "pass", admin=False) non_admin_user_tok = self.login("nonadmin", "pass") @@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Attempt to access the media - request, channel = self.make_request( - "GET", - server_name_and_media_id, - shorthand=False, - access_token=admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_name_and_media_id - ), - ) + self._ensure_quarantined(admin_user_tok, server_name_and_media_id) def test_quarantine_all_media_in_room(self, override_url_template=None): self.register_user("room_admin", "pass", admin=True) @@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_and_media_id_2 = mxc_2[6:] # Test that we cannot download any of the media anymore - request, channel = self.make_request( - "GET", - server_and_media_id_1, - shorthand=False, - access_token=non_admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_1 - ), - ) - - request, channel = self.make_request( - "GET", - server_and_media_id_2, - shorthand=False, - access_token=non_admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_2 - ), - ) + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + self._ensure_quarantined(admin_user_tok, server_and_media_id_2) - def test_quaraantine_all_media_in_room_deprecated_api_path(self): + def test_quarantine_all_media_in_room_deprecated_api_path(self): # Perform the above test with the deprecated API path self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s") @@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Attempt to access each piece of media + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + self._ensure_quarantined(admin_user_tok, server_and_media_id_2) + + def test_cannot_quarantine_safe_media(self): + self.register_user("user_admin", "pass", admin=True) + admin_user_tok = self.login("user_admin", "pass") + + non_admin_user = 
self.register_user("user_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("user_nonadmin", "pass") + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract media IDs + server_and_media_id_1 = response_1["content_uri"][6:] + server_and_media_id_2 = response_2["content_uri"][6:] + + # Mark the second item as safe from quarantine. + _, media_id_2 = server_and_media_id_2.split("/") + self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + + # Quarantine all media by this user + url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( + non_admin_user + ) request, channel = self.make_request( - "GET", - server_and_media_id_1, - shorthand=False, - access_token=non_admin_user_tok, + "POST", url.encode("ascii"), access_token=admin_user_tok, ) - request.render(self.download_resource) + self.render(request) self.pump(1.0) - - # Should be quarantined + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_1, - ), + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 1}, + "Expected 1 quarantined item", ) + # Attempt to access each piece of media, the first should fail, the + # second should succeed. + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + # Attempt to access each piece of media request, channel = self.make_request( "GET", @@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request.render(self.download_resource) self.pump(1.0) - # Should be quarantined + # Shouldn't be quarantined self.assertEqual( - 404, + 200, int(channel.code), msg=( - "Expected to receive a 404 on accessing quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) -- cgit 1.5.1 From 6920e58136671f086536332bdd6844dff0d4b429 Mon Sep 17 00:00:00 2001 From: Sorunome Date: Wed, 24 Jun 2020 11:23:55 +0200 Subject: add org.matrix.login.jwt so that m.login.jwt can be deprecated (#7675) --- changelog.d/7675.removal | 1 + synapse/rest/client/v1/login.py | 5 ++++- tests/rest/client/v1/test_login.py | 10 +++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7675.removal (limited to 'tests') diff --git a/changelog.d/7675.removal b/changelog.d/7675.removal new file mode 100644 index 0000000000..2500e2c578 --- /dev/null +++ b/changelog.d/7675.removal @@ -0,0 +1 @@ +Deprecate `m.login.jwt` login method in favour of `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec. 
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index c2c9a9c3aa..bf0f9bd077 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -81,7 +81,8 @@ class LoginRestServlet(RestServlet): CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" - JWT_TYPE = "m.login.jwt" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__() @@ -116,6 +117,7 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it @@ -149,6 +151,7 @@ class LoginRestServlet(RestServlet): try: if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 9033f09fd2..fd97999956 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -526,7 +526,9 @@ class JWTTestCase(unittest.HomeserverTestCase): return jwt.encode(token, secret, "HS256").decode("ascii") def jwt_login(self, *args): - params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + params = json.dumps( + {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + ) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) return channel @@ -568,7 +570,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["error"], "Invalid JWT") def test_login_no_token(self): - params = json.dumps({"type": "m.login.jwt"}) + params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) self.assertEqual(channel.result["code"], b"401", channel.result) @@ -640,7 +642,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return jwt.encode(token, secret, "RS256").decode("ascii") def jwt_login(self, *args): - params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + params = json.dumps( + {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + ) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) return channel -- cgit 1.5.1 From 0e0a2817a29391fd777f7ee683dc03d63cf40302 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 24 Jun 2020 18:48:18 +0100 Subject: Yield during large v2 state res. (#7735) State res v2 across large data sets can be very CPU intensive, and if all the relevant events are in the cache the algorithm will run from start to finish within a single reactor tick. This can result in blocking the reactor tick for several seconds, which can have major repercussions on other requests. To fix this we simply add the occaisonal `sleep(0)` during iterations to yield execution until the next reactor tick. 
The aim is to only do this for large data sets so that we don't impact otherwise quick resolutions.= --- changelog.d/7735.bugfix | 1 + synapse/handlers/federation.py | 1 + synapse/state/__init__.py | 6 ++++- synapse/state/v2.py | 56 ++++++++++++++++++++++++++++++++++-------- tests/state/test_v2.py | 9 +++++++ 5 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7735.bugfix (limited to 'tests') diff --git a/changelog.d/7735.bugfix b/changelog.d/7735.bugfix new file mode 100644 index 0000000000..86959a5ca4 --- /dev/null +++ b/changelog.d/7735.bugfix @@ -0,0 +1 @@ +Fix large state resolutions from stalling Synapse for seconds at a time. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 873f6bc39f..3828ff0ef0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -376,6 +376,7 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) state_map = await resolve_events_with_store( + self.clock, room_id, room_version, state_maps, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 50fd843f66..495d9f04c8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap +from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -414,6 +415,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + self.clock, event.room_id, room_version, state_set_ids, @@ -516,6 +518,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + self.clock, room_id, room_version, list(state_groups_ids.values()), @@ -589,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids): def resolve_events_with_store( + clock: Clock, room_id: str, room_version: str, state_sets: List[StateMap[str]], @@ -625,7 +629,7 @@ def resolve_events_with_store( ) else: return v2.resolve_events_with_store( - room_id, room_version, state_sets, event_map, state_res_store + clock, room_id, room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 57eadce4e6..7181ecda9a 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -27,12 +27,20 @@ from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap +from synapse.util import Clock logger = logging.getLogger(__name__) +# We want to yield to the reactor occasionally during state res when dealing +# with large data sets, so that we don't exhaust the reactor. This is done by +# yielding to reactor during loops every N iterations. 
+_YIELD_AFTER_ITERATIONS = 100 + + @defer.inlineCallbacks def resolve_events_with_store( + clock: Clock, room_id: str, room_version: str, state_sets: List[StateMap[str]], @@ -42,13 +50,11 @@ def resolve_events_with_store( """Resolves the state using the v2 state resolution algorithm Args: + clock room_id: the room we are working in - room_version: The room version - state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be @@ -113,7 +119,7 @@ def resolve_events_with_store( ) sorted_power_events = yield _reverse_topological_power_sort( - room_id, power_events, event_map, state_res_store, full_conflicted_set + clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) @@ -142,7 +148,7 @@ def resolve_events_with_store( pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - room_id, leftover_events, pl, event_map, state_res_store + clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") @@ -317,12 +323,13 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks def _reverse_topological_power_sort( - room_id, event_ids, event_map, state_res_store, auth_diff + clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + clock (Clock) room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) @@ -334,18 +341,28 @@ def _reverse_topological_power_sort( """ graph = {} - for event_id in event_ids: + for idx, event_id in enumerate(event_ids, start=1): yield _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + event_to_pl = {} - for event_id in graph: + for idx, event_id in enumerate(graph, start=1): pl = yield _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + def _get_power_order(event_id): ev = event_map[event_id] pl = event_to_pl[event_id] @@ -423,12 +440,13 @@ def _iterative_auth_checks( @defer.inlineCallbacks def _mainline_sort( - room_id, event_ids, resolved_power_event_id, event_map, state_res_store + clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + clock (Clock) room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID @@ -438,8 +456,14 @@ def _mainline_sort( Returns: Deferred[list[str]]: The sorted list """ + if not event_ids: + # It's possible for there to be no event IDs here to sort, so we can + # skip calculating the mainline in that case. 
+ return [] + mainline = [] pl = resolved_power_event_id + idx = 0 while pl: mainline.append(pl) pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) @@ -453,17 +477,29 @@ def _mainline_sort( pl = aid break + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + + idx += 1 + mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} event_ids = list(event_ids) order_map = {} - for ev_id in event_ids: + for idx, ev_id in enumerate(event_ids, start=1): depth = yield _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index cdc347bc53..38f9b423ef 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -17,6 +17,8 @@ import itertools import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event @@ -41,6 +43,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN} ORIGIN_SERVER_TS = 0 +class FakeClock: + def sleep(self, msec): + return defer.succeed(None) + + class FakeEvent(object): """A fake event we use as a convenience. @@ -417,6 +424,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + FakeClock(), ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], @@ -565,6 +573,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + FakeClock(), ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], -- cgit 1.5.1 From 71cccf1593bd73a1baef87483117b9be9a99b837 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 30 Jun 2020 15:41:36 -0400 Subject: Additional configuration options for auto-join rooms (#7763) --- changelog.d/7763.feature | 1 + docs/sample_config.yaml | 60 ++++++++++- synapse/config/registration.py | 106 +++++++++++++++++- synapse/handlers/register.py | 230 +++++++++++++++++++++++++++++----------- synapse/rest/admin/rooms.py | 4 +- tests/handlers/test_register.py | 212 +++++++++++++++++++++++++++++++++++- 6 files changed, 542 insertions(+), 71 deletions(-) create mode 100644 changelog.d/7763.feature (limited to 'tests') diff --git a/changelog.d/7763.feature b/changelog.d/7763.feature new file mode 100644 index 0000000000..4a7563dad3 --- /dev/null +++ b/changelog.d/7763.feature @@ -0,0 +1 @@ +Expand the configuration options for auto-join rooms. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 05e7bf215a..2d27b0b34d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1210,7 +1210,11 @@ account_threepid_delegates: #enable_3pid_changes: false # Users who register on this homeserver will automatically be joined -# to these rooms +# to these rooms. 
+# +# By default, any room aliases included in this list will be created +# as a publicly joinable room when the first user registers for the +# homeserver. This behaviour can be customised with the settings below. # #auto_join_rooms: # - "#example:example.com" @@ -1218,10 +1222,62 @@ account_threepid_delegates: # Where auto_join_rooms are specified, setting this flag ensures that the # the rooms exist by creating them when the first user on the # homeserver registers. +# +# By default the auto-created rooms are publicly joinable from any federated +# server. Use the autocreate_auto_join_rooms_federated and +# autocreate_auto_join_room_preset settings below to customise this behaviour. +# # Setting to false means that if the rooms are not manually created, # users cannot be auto-joined since they do not exist. # -#autocreate_auto_join_rooms: true +# Defaults to true. Uncomment the following line to disable automatically +# creating auto-join rooms. +# +#autocreate_auto_join_rooms: false + +# Whether the auto_join_rooms that are auto-created are available via +# federation. Only has an effect if autocreate_auto_join_rooms is true. +# +# Note that whether a room is federated cannot be modified after +# creation. +# +# Defaults to true: the room will be joinable from other servers. +# Uncomment the following to prevent users from other homeservers from +# joining these rooms. +# +#autocreate_auto_join_rooms_federated: false + +# The room preset to use when auto-creating one of auto_join_rooms. Only has an +# effect if autocreate_auto_join_rooms is true. +# +# This can be one of "public_chat", "private_chat", or "trusted_private_chat". +# If a value of "private_chat" or "trusted_private_chat" is used then +# auto_join_mxid_localpart must also be configured. +# +# Defaults to "public_chat", meaning that the room is joinable by anyone, including +# federated servers if autocreate_auto_join_rooms_federated is true (the default). +# Uncomment the following to require an invitation to join these rooms. +# +#autocreate_auto_join_room_preset: private_chat + +# The local part of the user id which is used to create auto_join_rooms if +# autocreate_auto_join_rooms is true. If this is not provided then the +# initial user account that registers will be used to create the rooms. +# +# The user id is also used to invite new users to any auto-join rooms which +# are set to invite-only. +# +# It *must* be configured if autocreate_auto_join_room_preset is set to +# "private_chat" or "trusted_private_chat". +# +# Note that this must be specified in order for new users to be correctly +# invited to any auto-join rooms which have been set to invite-only (either +# at the time of creation or subsequently). +# +# Note that, if the room already exists, this user must be joined and +# have the appropriate permissions to invite new members. +# +#auto_join_mxid_localpart: system # When auto_join_rooms is specified, setting this flag to false prevents # guest accounts from being automatically joined to the rooms. 
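Taken together, the new knobs above let a deployment auto-create invite-only rooms owned by a dedicated local user. A minimal sketch of the combined settings, written as the dict form the unit tests below pass to override_config (in homeserver.yaml these are top-level keys with the same names; the alias and localpart are examples):

    auto_join_settings = {
        "auto_join_rooms": ["#help:example.com"],
        "autocreate_auto_join_rooms": True,
        "autocreate_auto_join_rooms_federated": False,    # keep the created rooms local-only
        "autocreate_auto_join_room_preset": "private_chat",
        "auto_join_mxid_localpart": "system",             # required for the private presets
    }
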
diff --git a/synapse/config/registration.py b/synapse/config/registration.py index fecced2d57..6badf4e75d 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -18,8 +18,9 @@ from distutils.util import strtobool import pkg_resources +from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols @@ -127,7 +128,50 @@ class RegistrationConfig(Config): for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) + + # Options for creating auto-join rooms if they do not exist yet. self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.autocreate_auto_join_rooms_federated = config.get( + "autocreate_auto_join_rooms_federated", True + ) + self.autocreate_auto_join_room_preset = ( + config.get("autocreate_auto_join_room_preset") + or RoomCreationPreset.PUBLIC_CHAT + ) + self.auto_join_room_requires_invite = self.autocreate_auto_join_room_preset in { + RoomCreationPreset.PRIVATE_CHAT, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + } + + # Pull the creater/inviter from the configuration, this gets used to + # send invites for invite-only rooms. + mxid_localpart = config.get("auto_join_mxid_localpart") + self.auto_join_user_id = None + if mxid_localpart: + # Convert the localpart to a full mxid. + self.auto_join_user_id = UserID( + mxid_localpart, self.server_name + ).to_string() + + if self.autocreate_auto_join_rooms: + # Ensure the preset is a known value. + if self.autocreate_auto_join_room_preset not in { + RoomCreationPreset.PUBLIC_CHAT, + RoomCreationPreset.PRIVATE_CHAT, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + }: + raise ConfigError("Invalid value for autocreate_auto_join_room_preset") + # If the preset requires invitations to be sent, ensure there's a + # configured user to send them from. + if self.auto_join_room_requires_invite: + if not mxid_localpart: + raise ConfigError( + "The configuration option `auto_join_mxid_localpart` is required if " + "`autocreate_auto_join_room_preset` is set to private_chat or trusted_private_chat, such that " + "Synapse knows who to send invitations from. Please " + "configure `auto_join_mxid_localpart`." + ) + self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True) self.enable_set_displayname = config.get("enable_set_displayname", True) @@ -357,7 +401,11 @@ class RegistrationConfig(Config): #enable_3pid_changes: false # Users who register on this homeserver will automatically be joined - # to these rooms + # to these rooms. + # + # By default, any room aliases included in this list will be created + # as a publicly joinable room when the first user registers for the + # homeserver. This behaviour can be customised with the settings below. # #auto_join_rooms: # - "#example:example.com" @@ -365,10 +413,62 @@ class RegistrationConfig(Config): # Where auto_join_rooms are specified, setting this flag ensures that the # the rooms exist by creating them when the first user on the # homeserver registers. + # + # By default the auto-created rooms are publicly joinable from any federated + # server. Use the autocreate_auto_join_rooms_federated and + # autocreate_auto_join_room_preset settings below to customise this behaviour. 
+ # # Setting to false means that if the rooms are not manually created, # users cannot be auto-joined since they do not exist. # - #autocreate_auto_join_rooms: true + # Defaults to true. Uncomment the following line to disable automatically + # creating auto-join rooms. + # + #autocreate_auto_join_rooms: false + + # Whether the auto_join_rooms that are auto-created are available via + # federation. Only has an effect if autocreate_auto_join_rooms is true. + # + # Note that whether a room is federated cannot be modified after + # creation. + # + # Defaults to true: the room will be joinable from other servers. + # Uncomment the following to prevent users from other homeservers from + # joining these rooms. + # + #autocreate_auto_join_rooms_federated: false + + # The room preset to use when auto-creating one of auto_join_rooms. Only has an + # effect if autocreate_auto_join_rooms is true. + # + # This can be one of "public_chat", "private_chat", or "trusted_private_chat". + # If a value of "private_chat" or "trusted_private_chat" is used then + # auto_join_mxid_localpart must also be configured. + # + # Defaults to "public_chat", meaning that the room is joinable by anyone, including + # federated servers if autocreate_auto_join_rooms_federated is true (the default). + # Uncomment the following to require an invitation to join these rooms. + # + #autocreate_auto_join_room_preset: private_chat + + # The local part of the user id which is used to create auto_join_rooms if + # autocreate_auto_join_rooms is true. If this is not provided then the + # initial user account that registers will be used to create the rooms. + # + # The user id is also used to invite new users to any auto-join rooms which + # are set to invite-only. + # + # It *must* be configured if autocreate_auto_join_room_preset is set to + # "private_chat" or "trusted_private_chat". + # + # Note that this must be specified in order for new users to be correctly + # invited to any auto-join rooms which have been set to invite-only (either + # at the time of creation or subsequently). + # + # Note that, if the room already exists, this user must be joined and + # have the appropriate permissions to invite new members. + # + #auto_join_mxid_localpart: system # When auto_join_rooms is specified, setting this flag to false prevents # guest accounts from being automatically joined to the rooms. 
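When auto_join_mxid_localpart is set, the parsed config above turns it into a full Matrix user ID on the local server, and that user is the one who creates the rooms and sends any invites. A standalone illustration of the derivation (plain Python, not the Synapse UserID type):

    def auto_join_user_id(localpart: str, server_name: str) -> str:
        # Matrix user IDs take the form @localpart:server_name
        return "@%s:%s" % (localpart, server_name)

    assert auto_join_user_id("system", "example.com") == "@system:example.com"
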
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 51979ea43e..78c3772ac1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -17,7 +17,7 @@ import logging from synapse import types -from synapse.api.constants import MAX_USERID_LENGTH, LoginType +from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict @@ -26,7 +26,8 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.storage.state import StateFilter +from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -270,51 +271,157 @@ class RegistrationHandler(BaseHandler): return user_id - async def _auto_join_rooms(self, user_id): - """Automatically joins users to auto join rooms - creating the room in the first place - if the user is the first to be created. + async def _create_and_join_rooms(self, user_id: str): + """ + Create the auto-join rooms and join or invite the user to them. + + This should only be called when the first "real" user registers. Args: - user_id(str): The user to join + user_id: The user to join """ - # auto-join the user to any rooms we're supposed to dump them into - fake_requester = create_requester(user_id) + # Getting the handlers during init gives a dependency loop. + room_creation_handler = self.hs.get_room_creation_handler() + room_member_handler = self.hs.get_room_member_handler() - # try to create the room if we're the first real user on the server. Note - # that an auto-generated support or bot user is not a real user and will never be - # the user to create the room - should_auto_create_rooms = False - is_real_user = await self.store.is_real_user(user_id) - if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = await self.store.count_real_users() - should_auto_create_rooms = count == 1 - for r in self.hs.config.auto_join_rooms: + # Generate a stub for how the rooms will be configured. + stub_config = { + "preset": self.hs.config.registration.autocreate_auto_join_room_preset, + } + + # If the configuration providers a user ID to create rooms with, use + # that instead of the first user registered. + requires_join = False + if self.hs.config.registration.auto_join_user_id: + fake_requester = create_requester( + self.hs.config.registration.auto_join_user_id + ) + + # If the room requires an invite, add the user to the list of invites. + if self.hs.config.registration.auto_join_room_requires_invite: + stub_config["invite"] = [user_id] + + # If the room is being created by a different user, the first user + # registered needs to join it. Note that in the case of an invitation + # being necessary this will occur after the invite was sent. + requires_join = True + else: + fake_requester = create_requester(user_id) + + # Choose whether to federate the new room. 
+ if not self.hs.config.registration.autocreate_auto_join_rooms_federated: + stub_config["creation_content"] = {"m.federate": False} + + for r in self.hs.config.registration.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) + try: - if should_auto_create_rooms: - room_alias = RoomAlias.from_string(r) - if self.hs.hostname != room_alias.domain: - logger.warning( - "Cannot create room alias %s, " - "it does not match server domain", - r, - ) - else: - # create room expects the localpart of the room alias - room_alias_localpart = room_alias.localpart - - # getting the RoomCreationHandler during init gives a dependency - # loop - await self.hs.get_room_creation_handler().create_room( - fake_requester, - config={ - "preset": "public_chat", - "room_alias_name": room_alias_localpart, - }, + room_alias = RoomAlias.from_string(r) + + if self.hs.hostname != room_alias.domain: + logger.warning( + "Cannot create room alias %s, " + "it does not match server domain", + r, + ) + else: + # A shallow copy is OK here since the only key that is + # modified is room_alias_name. + config = stub_config.copy() + # create room expects the localpart of the room alias + config["room_alias_name"] = room_alias.localpart + + info, _ = await room_creation_handler.create_room( + fake_requester, config=config, ratelimit=False, + ) + + # If the room does not require an invite, but another user + # created it, then ensure the first user joins it. + if requires_join: + await room_member_handler.update_membership( + requester=create_requester(user_id), + target=UserID.from_string(user_id), + room_id=info["room_id"], + # Since it was just created, there are no remote hosts. + remote_room_hosts=[], + action="join", ratelimit=False, ) + + except ConsentNotGivenError as e: + # Technically not necessary to pull out this error though + # moving away from bare excepts is a good thing to do. + logger.error("Failed to join new user to %r: %r", r, e) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + + async def _join_rooms(self, user_id: str): + """ + Join or invite the user to the auto-join rooms. + + Args: + user_id: The user to join + """ + room_member_handler = self.hs.get_room_member_handler() + + for r in self.hs.config.registration.auto_join_rooms: + logger.info("Auto-joining %s to %s", user_id, r) + + try: + room_alias = RoomAlias.from_string(r) + + if RoomAlias.is_valid(r): + ( + room_id, + remote_room_hosts, + ) = await room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() else: - await self._join_user_to_room(fake_requester, r) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (r,) + ) + + # Calculate whether the room requires an invite or can be + # joined directly. Note that unless a join rule of public exists, + # it is treated as requiring an invite. + requires_invite = True + + state = await self.store.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) + ) + + event_id = state.get((EventTypes.JoinRules, "")) + if event_id: + join_rules_event = await self.store.get_event( + event_id, allow_none=True + ) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + requires_invite = join_rule and join_rule != JoinRules.PUBLIC + + # Send the invite, if necessary. 
+ if requires_invite: + await room_member_handler.update_membership( + requester=create_requester( + self.hs.config.registration.auto_join_user_id + ), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="invite", + ratelimit=False, + ) + + # Send the join. + await room_member_handler.update_membership( + requester=create_requester(user_id), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ratelimit=False, + ) + except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -322,6 +429,29 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) + async def _auto_join_rooms(self, user_id: str): + """Automatically joins users to auto join rooms - creating the room in the first place + if the user is the first to be created. + + Args: + user_id: The user to join + """ + # auto-join the user to any rooms we're supposed to dump them into + + # try to create the room if we're the first real user on the server. Note + # that an auto-generated support or bot user is not a real user and will never be + # the user to create the room + should_auto_create_rooms = False + is_real_user = await self.store.is_real_user(user_id) + if self.hs.config.registration.autocreate_auto_join_rooms and is_real_user: + count = await self.store.count_real_users() + should_auto_create_rooms = count == 1 + + if should_auto_create_rooms: + await self._create_and_join_rooms(user_id) + else: + await self._join_rooms(user_id) + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted @@ -392,30 +522,6 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id += 1 return str(id) - async def _join_user_to_room(self, requester, room_identifier): - room_member_handler = self.hs.get_room_member_handler() - if RoomID.is_valid(room_identifier): - room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( - room_alias - ) - room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - - await room_member_handler.update_membership( - requester=requester, - target=requester.user, - room_id=room_id, - remote_room_hosts=remote_room_hosts, - action="join", - ratelimit=False, - ) - def check_registration_ratelimit(self, address): """A simple helper method to check whether the registration rate limit has been hit for a given IP address diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 8173baef8f..e07c32118d 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -15,7 +15,7 @@ import logging from typing import List, Optional -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -77,7 +77,7 @@ class ShutdownRoomRestServlet(RestServlet): info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ - "preset": "public_chat", + "preset": RoomCreationPreset.PUBLIC_CHAT, 
"name": room_name, "power_level_content_override": {"users_default": -10}, }, diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index ca32f993a3..6d45c4b233 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,6 +22,8 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester +from tests.unittest import override_config + from .. import unittest @@ -145,9 +147,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -193,9 +195,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] self.store.count_real_users = Mock(return_value=defer.succeed(1)) self.store.is_real_user = Mock(return_value=defer.succeed(True)) @@ -218,6 +220,212 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_rooms_federated": False, + } + ) + def test_auto_create_auto_join_rooms_federated(self): + """ + Auto-created rooms that are private require an invite to go to the user + (instead of directly joining it). + """ + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly not federated. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertFalse(room["federatable"]) + self.assertFalse(room["public"]) + self.assertEqual(room["join_rules"], "public") + self.assertIsNone(room["guest_access"]) + + # The user should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"} + ) + def test_auto_join_mxid_localpart(self): + """ + Ensure the user still needs up in the room created by a different user. + """ + # Ensure the support user exists. + inviter = "@support:test" + + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. 
+ directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly a public room. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertEqual(room["join_rules"], "public") + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + # Register a second user, which should also end up in the room. + user_id = self.get_success(self.handler.register_user(localpart="bob")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_room_preset": "private_chat", + "auto_join_mxid_localpart": "support", + } + ) + def test_auto_create_auto_join_room_preset(self): + """ + Auto-created rooms that are private require an invite to go to the user + (instead of directly joining it). + """ + # Ensure the support user exists. + inviter = "@support:test" + + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly a private room. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertFalse(room["public"]) + self.assertEqual(room["join_rules"], "invite") + self.assertEqual(room["guest_access"], "can_join") + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + # Register a second user, which should also end up in the room. + user_id = self.get_success(self.handler.register_user(localpart="bob")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_room_preset": "private_chat", + "auto_join_mxid_localpart": "support", + } + ) + def test_auto_create_auto_join_room_preset_guest(self): + """ + Auto-created rooms that are private require an invite to go to the user + (instead of directly joining it). + + This should also work for guests. + """ + inviter = "@support:test" + + room_alias_str = "#room:test" + user_id = self.get_success( + self.handler.register_user(localpart="jeff", make_guest=True) + ) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly a private room. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertFalse(room["public"]) + self.assertEqual(room["join_rules"], "invite") + self.assertEqual(room["guest_access"], "can_join") + + # Both users should be in the room. 
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_room_preset": "private_chat", + "auto_join_mxid_localpart": "support", + } + ) + def test_auto_create_auto_join_room_preset_invalid_permissions(self): + """ + Auto-created rooms that are private require an invite, check that + registration doesn't completely break if the inviter doesn't have proper + permissions. + """ + inviter = "@support:test" + + # Register an initial user to create the room and such (essentially this + # is a subset of test_auto_create_auto_join_room_preset). + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room exists. + self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + # Lower the permissions of the inviter. + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(inviter) + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.power_levels", + "state_key": "", + "room_id": room_id["room_id"], + "content": {"invite": 100, "users": {inviter: 0}}, + "sender": inviter, + }, + ) + ) + self.get_success( + event_creation_handler.send_nonmember_event(requester, event, context) + ) + + # Register a second user, which won't be be in the room (or even have an invite) + # since the inviter no longer has the proper permissions. + user_id = self.get_success(self.handler.register_user(localpart="bob")) + + # This user should not be in any rooms. + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(user_id) + ) + self.assertEqual(rooms, set()) + self.assertEqual(invited_rooms, []) + def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. 
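The tests above exercise several auto-join settings in combination. As a rough guide to how they fit together, the sketch below lists the configuration keys the tests pass to override_config and restates, as a standalone helper (illustrative only, not Synapse code), the invite-then-join behaviour the tests expect for the "private_chat" preset.

# Keys as used by the tests above via override_config().
AUTO_JOIN_EXAMPLE_CONFIG = {
    "auto_join_rooms": ["#room:test"],
    # Create the room when the first real (non-support, non-bot) user registers.
    "autocreate_auto_join_rooms": True,
    # Keep the auto-created room off federation.
    "autocreate_auto_join_rooms_federated": False,
    # Create the room invite-only rather than public.
    "autocreate_auto_join_room_preset": "private_chat",
    # Local user that creates the room and sends the invites.
    "auto_join_mxid_localpart": "support",
}


def needs_invite_before_join(config):
    # Mirrors only what the tests above check: with the private_chat preset
    # the room cannot simply be joined, so the configured auto-join user has
    # to invite each newly registered user before the join is sent.
    return config.get("autocreate_auto_join_room_preset") == "private_chat"


assert needs_invite_before_join(AUTO_JOIN_EXAMPLE_CONFIG)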
-- cgit 1.5.1 From 74d3e177f0443f27e670f0b99299d715c58fd238 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 1 Jul 2020 11:08:25 +0100 Subject: Back out MSC2625 implementation (#7761) --- changelog.d/7673.feature | 1 - changelog.d/7716.feature | 1 - changelog.d/7761.feature | 1 + synapse/handlers/sync.py | 3 - synapse/push/bulk_push_rule_evaluator.py | 7 +- synapse/push/push_tools.py | 5 +- synapse/rest/client/v1/push_rule.py | 4 +- .../storage/data_stores/main/event_push_actions.py | 133 +++++---------------- .../delta/58/07push_summary_unread_count.sql | 23 ---- tests/replication/slave/storage/test_events.py | 19 +-- tests/storage/test_event_push_actions.py | 45 +++---- 11 files changed, 53 insertions(+), 189 deletions(-) delete mode 100644 changelog.d/7673.feature delete mode 100644 changelog.d/7716.feature create mode 100644 changelog.d/7761.feature delete mode 100644 synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql (limited to 'tests') diff --git a/changelog.d/7673.feature b/changelog.d/7673.feature deleted file mode 100644 index ecc3ffd8d5..0000000000 --- a/changelog.d/7673.feature +++ /dev/null @@ -1 +0,0 @@ -Add a per-room counter for unread messages in responses to `/sync` requests. Implements [MSC2625](https://github.com/matrix-org/matrix-doc/pull/2625). diff --git a/changelog.d/7716.feature b/changelog.d/7716.feature deleted file mode 100644 index ecc3ffd8d5..0000000000 --- a/changelog.d/7716.feature +++ /dev/null @@ -1 +0,0 @@ -Add a per-room counter for unread messages in responses to `/sync` requests. Implements [MSC2625](https://github.com/matrix-org/matrix-doc/pull/2625). diff --git a/changelog.d/7761.feature b/changelog.d/7761.feature new file mode 100644 index 0000000000..c97864677a --- /dev/null +++ b/changelog.d/7761.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0b82aa72a6..4c7524493e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1893,9 +1893,6 @@ class SyncHandler(object): if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] unread_notifications["highlight_count"] = notifs["highlight_count"] - unread_notifications["org.matrix.msc2625.unread_count"] = notifs[ - "unread_count" - ] sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 5b00602a56..43ffe6faf0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -189,11 +189,8 @@ class BulkPushRuleEvaluator(object): ) if matches: actions = [x for x in rule["actions"] if x != "dont_notify"] - if ( - "notify" in actions - or "org.matrix.msc2625.mark_unread" in actions - ): - # Push rules say we should act on this event. + if actions and "notify" in actions: + # Push rules say we should notify the user of this event actions_by_user[uid] = actions break diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 4ea683fee0..5dae4648c0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,10 +39,7 @@ def get_badge_count(store, user_id): ) # return one badge count per conversation, as count per # message is so noisy as to be almost useless - # We're populating this badge using the unread_count (instead of the - # notify_count) as this badge is the number of missed messages, not the - # number of missed notifications. 
- badge += 1 if notifs.get("unread_count") else 0 + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index f563b3dc35..9fd4908136 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2020 The Matrix.org Foundation C.I.C. +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -267,7 +267,7 @@ def _check_actions(actions): raise InvalidRuleException("No actions found") for a in actions: - if a in ["notify", "dont_notify", "coalesce", "org.matrix.msc2625.mark_unread"]: + if a in ["notify", "dont_notify", "coalesce"]: pass elif isinstance(a, dict) and "set_tweak" in a: pass diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 815d52ab4c..bc9f4f08ea 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2015-2020 The Matrix.org Foundation C.I.C. +# Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +15,7 @@ # limitations under the License. import logging -from typing import Dict, Tuple -import attr from canonicaljson import json from twisted.internet import defer @@ -37,16 +36,6 @@ DEFAULT_HIGHLIGHT_ACTION = [ ] -@attr.s -class EventPushSummary: - """Summary of pending event push actions for a given user in a given room.""" - - unread_count = attr.ib(type=int) - stream_ordering = attr.ib(type=int) - old_user_id = attr.ib(type=str) - notif_count = attr.ib(type=int) - - def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. @@ -123,7 +112,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, last_read_event_id)) results = txn.fetchall() if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0, "unread_count": 0} + return {"notify_count": 0, "highlight_count": 0} stream_ordering = results[0][0] @@ -133,42 +122,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - # First get number of actions, grouped on whether the action notifies. + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 sql = ( - "SELECT count(*), notif" + "SELECT count(*)" " FROM event_push_actions ea" " WHERE" " user_id = ?" " AND room_id = ?" " AND stream_ordering > ?" - " GROUP BY notif" ) - txn.execute(sql, (user_id, room_id, stream_ordering)) - rows = txn.fetchall() - # We should get a maximum number of two rows: one for notif = 0, which is the - # number of actions that contribute to the unread_count but not to the - # notify_count, and one for notif = 1, which is the number of actions that - # contribute to both counters. If one or both rows don't appear, then the - # value for the matching counter should be 0. - unread_count = 0 - notify_count = 0 - for row in rows: - # We always increment unread_count because actions that notify also - # contribute to it. 
- unread_count += row[0] - if row[1] == 1: - notify_count = row[0] - elif row[1] != 0: - logger.warning( - "Unexpected value %d for column 'notif' in table" - " 'event_push_actions'", - row[1], - ) + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + notify_count = row[0] if row else 0 txn.execute( """ - SELECT notif_count, unread_count FROM event_push_summary + SELECT notif_count FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering > ? """, (room_id, user_id, stream_ordering), @@ -176,7 +148,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): rows = txn.fetchall() if rows: notify_count += rows[0][0] - unread_count += rows[0][1] # Now get the number of highlights sql = ( @@ -193,11 +164,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): row = txn.fetchone() highlight_count = row[0] if row else 0 - return { - "unread_count": unread_count, - "notify_count": notify_count, - "highlight_count": highlight_count, - } + return {"notify_count": notify_count, "highlight_count": highlight_count} @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): @@ -255,7 +222,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -284,7 +250,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -357,7 +322,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -386,7 +350,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -436,7 +399,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_if_maybe_push_in_range_for_user_txn(txn): sql = """ SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? AND notif = 1 + WHERE user_id = ? AND stream_ordering > ? LIMIT 1 """ @@ -465,15 +428,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): return # This is a helper function for generating the necessary tuple that - # can be used to insert into the `event_push_actions_staging` table. + # can be used to inert into the `event_push_actions_staging` table. 
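# Illustrative sketch, not part of this patch: the split that the read path
# above and the summary rotation further down both rely on. Counts for rows
# already rotated out of event_push_actions are kept in event_push_summary,
# and the read path adds the remaining live rows on top. The names below are
# local to this sketch.
def unread_notify_count(live_rows, summary, since):
    # live_rows: stream orderings of un-rotated event_push_actions rows.
    # summary: (notif_count, stream_ordering) from event_push_summary, or None.
    # since: stream ordering of the user's last read receipt.
    count = sum(1 for stream_ordering in live_rows if stream_ordering > since)
    if summary is not None:
        notif_count, summary_ordering = summary
        if summary_ordering > since:
            count += notif_count
    return count


# Two rotated-away notifications plus one live row after the read receipt.
assert unread_notify_count([3, 7], summary=(2, 5), since=4) == 3
# (_gen_entry below builds the per-user staging rows that eventually become
# the live event_push_actions rows counted here.)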
def _gen_entry(user_id, actions): is_highlight = 1 if _action_has_highlight(actions) else 0 - notif = 0 if "org.matrix.msc2625.mark_unread" in actions else 1 return ( event_id, # event_id column user_id, # user_id column _serialize_action(actions, is_highlight), # actions column - notif, # notif column + 1, # notif column is_highlight, # highlight column ) @@ -855,51 +817,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.%s, 0) + upd.cnt, + coalesce(old.notif_count, 0) + upd.notif_count, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, count(*) as notif_count, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0 - %s GROUP BY user_id, room_id ) AS upd LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - # First get the count of unread messages. - txn.execute( - sql % ("unread_count", ""), - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - # We need to merge both lists into a single object because we might not have the - # same amount of rows in each of them. In this case we use a dict indexed on the - # user ID and room ID to make it easier to populate. - summaries = {} # type: Dict[Tuple[str, str], EventPushSummary] - for row in txn: - summaries[(row[0], row[1])] = EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], - old_user_id=row[4], - notif_count=0, - ) - - # Then get the count of notifications. - txn.execute( - sql % ("notif_count", "AND notif = 1"), - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - # notif_rows is populated based on a subset of the query used to populate - # unread_rows, so we can be sure that there will be no KeyError here. - for row in txn: - summaries[(row[0], row[1])].notif_count = row[2] + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) + rows = txn.fetchall() - logger.info("Rotating notifications, handling %d rows", len(summaries)) + logger.info("Rotating notifications, handling %d rows", len(rows)) # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the @@ -909,34 +844,22 @@ class EventPushActionsStore(EventPushActionsWorkerStore): table="event_push_summary", values=[ { - "user_id": user_id, - "room_id": room_id, - "notif_count": summary.notif_count, - "unread_count": summary.unread_count, - "stream_ordering": summary.stream_ordering, + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], } - for ((user_id, room_id), summary) in summaries.items() - if summary.old_user_id is None + for row in rows + if row[4] is None ], ) txn.executemany( """ - UPDATE event_push_summary - SET notif_count = ?, unread_count = ?, stream_ordering = ? + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? 
""", - ( - ( - summary.notif_count, - summary.unread_count, - summary.stream_ordering, - user_id, - room_id, - ) - for ((user_id, room_id), summary) in summaries.items() - if summary.old_user_id is not None - ), + ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), ) txn.execute( diff --git a/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql b/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql deleted file mode 100644 index f1459ef7f0..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store the number of unread messages, i.e. messages that triggered either a notify --- action or a mark_unread one. -ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0; - --- Pre-populate the new column with the count of pending notifications. --- We expect event_push_summary to be relatively small, so we can do this update --- synchronously without impacting Synapse's startup time too much. -UPDATE event_push_summary SET unread_count = notif_count; \ No newline at end of file diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index cd8680e812..1a88c7fb80 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0, "unread_count": 0}, + {"highlight_count": 0, "notify_count": 0}, ) self.persist( @@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1, "unread_count": 1}, + {"highlight_count": 0, "notify_count": 1}, ) self.persist( @@ -188,20 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2, "unread_count": 2}, - ) - - self.persist( - type="m.room.message", - msgtype="m.text", - body="world", - push_actions=[(USER_ID_2, ["org.matrix.msc2625.mark_unread"])], - ) - self.replicate() - self.check( - "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2, "unread_count": 3}, + {"highlight_count": 1, "notify_count": 2}, ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 303dc8571c..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ 
b/tests/storage/test_event_push_actions.py @@ -22,10 +22,6 @@ import tests.utils USER_ID = "@user:example.com" -MARK_UNREAD = [ - "org.matrix.msc2625.mark_unread", - {"set_tweak": "highlight", "value": False}, -] PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] HIGHLIGHT = [ "notify", @@ -59,17 +55,13 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): user_id = "@user1235:example.com" @defer.inlineCallbacks - def _assert_counts(unread_count, notif_count, highlight_count): + def _assert_counts(noitf_count, highlight_count): counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( counts, - { - "unread_count": unread_count, - "notify_count": notif_count, - "highlight_count": highlight_count, - }, + {"notify_count": noitf_count, "highlight_count": highlight_count}, ) @defer.inlineCallbacks @@ -104,23 +96,23 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): stream, ) - yield _assert_counts(0, 0, 0) + yield _assert_counts(0, 0) yield _inject_actions(1, PlAIN_NOTIF) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _rotate(2) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _inject_actions(3, PlAIN_NOTIF) - yield _assert_counts(2, 2, 0) + yield _assert_counts(2, 0) yield _rotate(4) - yield _assert_counts(2, 2, 0) + yield _assert_counts(2, 0) yield _inject_actions(5, PlAIN_NOTIF) yield _mark_read(3, 3) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _mark_read(5, 5) - yield _assert_counts(0, 0, 0) + yield _assert_counts(0, 0) yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) @@ -129,22 +121,17 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): table="event_push_actions", keyvalues={"1": 1}, desc="" ) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _mark_read(7, 7) - yield _assert_counts(0, 0, 0) + yield _assert_counts(0, 0) - yield _inject_actions(8, MARK_UNREAD) - yield _assert_counts(1, 0, 0) + yield _inject_actions(8, HIGHLIGHT) + yield _assert_counts(1, 1) yield _rotate(9) - yield _assert_counts(1, 0, 0) - - yield _inject_actions(10, HIGHLIGHT) - yield _assert_counts(2, 1, 1) - yield _rotate(11) - yield _assert_counts(2, 1, 1) - yield _rotate(12) - yield _assert_counts(2, 1, 1) + yield _assert_counts(1, 1) + yield _rotate(10) + yield _assert_counts(1, 1) @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): -- cgit 1.5.1 From e5808c4cfbec60f11f358bea529b321e94751ec9 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Wed, 1 Jul 2020 17:02:31 +0100 Subject: Hack to add push priority to push notifications (#7765) * Remove obsolete comment about ancient temporary code Signed-off-by: Olivier Wilkinson (reivilibre) * Implement hack to set push priority based on whether the tweaks indicate the event might cause effects. * Changelog for 7765 Signed-off-by: Olivier Wilkinson (reivilibre) * Antilint * Add tests for push priority Signed-off-by: Olivier Wilkinson (reivilibre) * Update synapse/push/httppusher.py Co-authored-by: Brendan Abolivier * Antilint * Remove needless invites from tests. 
Co-authored-by: Brendan Abolivier --- changelog.d/7765.misc | 1 + synapse/push/httppusher.py | 17 ++- tests/push/test_http.py | 352 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 362 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7765.misc (limited to 'tests') diff --git a/changelog.d/7765.misc b/changelog.d/7765.misc new file mode 100644 index 0000000000..fa9cfd24cb --- /dev/null +++ b/changelog.d/7765.misc @@ -0,0 +1 @@ +Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index ed60dbc1bf..2fac07593b 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -20,6 +20,7 @@ from prometheus_client import Counter from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.api.constants import EventTypes from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException @@ -305,12 +306,23 @@ class HttpPusher(object): @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): + priority = "low" + if ( + event.type == EventTypes.Encrypted + or tweaks.get("highlight") + or tweaks.get("sound") + ): + # HACK send our push as high priority only if it generates a sound, highlight + # or may do so (i.e. is encrypted so has unknown effects). + priority = "high" + if self.data.get("format") == "event_id_only": d = { "notification": { "event_id": event.event_id, "room_id": event.room_id, "counts": {"unread": badge}, + "prio": priority, "devices": [ { "app_id": self.app_id, @@ -334,9 +346,8 @@ class HttpPusher(object): "room_id": event.room_id, "type": event.type, "sender": event.user_id, - "counts": { # -- we don't mark messages as read yet so - # we have no way of knowing - # Just set the badge to 1 until we have read receipts + "prio": priority, + "counts": { "unread": badge, # 'missed_calls': 2 }, diff --git a/tests/push/test_http.py b/tests/push/test_http.py index baf9c785f4..b567868b02 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -25,7 +25,6 @@ from tests.unittest import HomeserverTestCase class HTTPPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -35,7 +34,6 @@ class HTTPPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor, clock): - self.push_attempts = [] m = Mock() @@ -90,9 +88,6 @@ class HTTPPusherTests(HomeserverTestCase): # Create a room room = self.helper.create_room_as(user_id, tok=access_token) - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - # The other user joins self.helper.join(room=room, user=other_user_id, tok=other_access_token) @@ -157,3 +152,350 @@ class HTTPPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + + def test_sends_high_priority_for_encrypted(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to an encrypted message. + This will happen both in 1:1 rooms and larger rooms. 
+ """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send an encrypted event + # I know there'd normally be set-up of an encrypted room first + # but this will do for our purposes + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "6lImKbzK51MzWLwHh8tUM3UBBSBrLlgup/OOCGTvumM", + "ciphertext": "AwgAErABoRxwpMipdgiwXgu46rHiWQ0DmRj0qUlPrMraBUDk" + "leTnJRljpuc7IOhsYbLY3uo2WI0ab/ob41sV+3JEIhODJPqH" + "TK7cEZaIL+/up9e+dT9VGF5kRTWinzjkeqO8FU5kfdRjm+3w" + "0sy3o1OCpXXCfO+faPhbV/0HuK4ndx1G+myNfK1Nk/CxfMcT" + "BT+zDS/Df/QePAHVbrr9uuGB7fW8ogW/ulnydgZPRluusFGv" + "J3+cg9LoPpZPAmv5Me3ec7NtdlfN0oDZ0gk3TiNkkhsxDG9Y" + "YcNzl78USI0q8+kOV26Bu5dOBpU4WOuojXZHJlP5lMgdzLLl" + "EQ0", + "session_id": "IigqfNWLL+ez/Is+Duwp2s4HuCZhFG9b9CZKTYHtQ4A", + "device_id": "AHQDUSTAAA", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Add yet another person — we want to make this room not a 1:1 + # (as encrypted messages in a 1:1 currently have tweaks applied + # so it doesn't properly exercise the condition of all encrypted + # messages need to be high). 
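# Illustrative sketch, not part of this test: the rule added to the HTTP
# pusher above, restated as a standalone function so the expectations in
# these tests are easier to follow. The function name is local to the sketch.
def push_priority(event_type, tweaks):
    # Encrypted events could contain anything once decrypted, so they are
    # treated like events that highlight or play a sound.
    if event_type == "m.room.encrypted":
        return "high"
    if tweaks.get("highlight") or tweaks.get("sound"):
        return "high"
    return "low"


assert push_priority("m.room.encrypted", {}) == "high"
assert push_priority("m.room.message", {"sound": "default"}) == "high"
assert push_priority("m.room.message", {"highlight": False}) == "low"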
+ self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another encrypted event + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "ciphertext": "AwgAEoABtEuic/2DF6oIpNH+q/PonzlhXOVho8dTv0tzFr5m" + "9vTo50yabx3nxsRlP2WxSqa8I07YftP+EKWCWJvTkg6o7zXq" + "6CK+GVvLQOVgK50SfvjHqJXN+z1VEqj+5mkZVN/cAgJzoxcH" + "zFHkwDPJC8kQs47IHd8EO9KBUK4v6+NQ1uE/BIak4qAf9aS/" + "kI+f0gjn9IY9K6LXlah82A/iRyrIrxkCkE/n0VfvLhaWFecC" + "sAWTcMLoF6fh1Jpke95mljbmFSpsSd/eEQw", + "device_id": "SRCFTWTHXO", + "session_id": "eMA+bhGczuTz1C5cJR1YbmrnnC6Goni4lbvS5vJ1nG4", + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "rC/XSIAiYrVGSuaHMop8/pTZbku4sQKBZwRwukgnN1c", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") + + def test_sends_high_priority_for_one_to_one_only(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message in a one-to-one room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Hi!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority — this is a one-to-one room + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Yet another user joins + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another event + self.helper.send(room, body="Welcome!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], 
"example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_mention(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message containing the user's display name. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other users join + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Oh, user, hello!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time with no mention + self.helper.send(room, body="Are you there?", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_atroom(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message that contains @room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room (as other_user so the power levels are compatible with + # other_user sending @room). 
+ room = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The other users join + self.helper.join(room=room, user=user_id, tok=access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send( + room, + body="@room eeek! There's a spider on the table!", + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time as someone without the power of @room + self.helper.send( + room, body="@room the spider is gone", tok=yet_another_access_token + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") -- cgit 1.5.1 From 21a212f8e50343e9b55944fa75ece7911fd2cb70 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 3 Jul 2020 15:03:13 +0200 Subject: Fix inconsistent handling of upper and lower cases of email addresses. (#7021) fixes #7016 --- changelog.d/7021.bugfix | 1 + synapse/handlers/auth.py | 5 +- synapse/rest/client/v1/login.py | 12 +- synapse/rest/client/v2_alpha/account.py | 40 +++++-- synapse/rest/client/v2_alpha/register.py | 22 +++- synapse/util/threepids.py | 23 ++++ tests/rest/client/v2_alpha/test_account.py | 175 ++++++++++++++++++++++++----- tests/util/test_threepids.py | 49 ++++++++ 8 files changed, 279 insertions(+), 48 deletions(-) create mode 100644 changelog.d/7021.bugfix create mode 100644 tests/util/test_threepids.py (limited to 'tests') diff --git a/changelog.d/7021.bugfix b/changelog.d/7021.bugfix new file mode 100644 index 0000000000..140fe37b2d --- /dev/null +++ b/changelog.d/7021.bugfix @@ -0,0 +1 @@ +Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c3f86e7414..d713a06bf9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -45,6 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID +from synapse.util.threepids import canonicalise_email from ._base import BaseHandler @@ -928,7 +929,7 @@ class AuthHandler(BaseHandler): # for the presence of an email address during password reset was # case sensitive). 
if medium == "email": - address = address.lower() + address = canonicalise_email(address) await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() @@ -956,7 +957,7 @@ class AuthHandler(BaseHandler): # 'Canonicalise' email addresses as per above if medium == "email": - address = address.lower() + address = canonicalise_email(address) identity_handler = self.hs.get_handlers().identity_handler result = await identity_handler.try_unbind_threepid( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index bf0f9bd077..f6eef7afee 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -28,6 +28,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) @@ -206,11 +207,14 @@ class LoginRestServlet(RestServlet): if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) if medium == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - address = address.lower() + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 182a308eef..3767a809a4 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -30,7 +30,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed +from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -83,7 +83,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. This allows the user to reset his password without having to + # know the exact spelling (eg. upper and lower case) of address in the database. + # Stored in the database "foo@bar.com" + # User requests with "FOO@bar.com" would raise a Not Found error + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -94,6 +102,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # The email will be sent to the stored address. + # This avoids a potential account hijack by requesting a password reset to + # an email address which is controlled by the attacker but which, after + # canonicalisation, matches the one in our database. 
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -274,10 +286,13 @@ class PasswordRestServlet(RestServlet): if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") if threepid["medium"] == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid["address"] = threepid["address"].lower() + try: + threepid["address"] = canonicalise_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) # if using email, we must know about the email they're authing with! threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] @@ -392,7 +407,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. + # This ensures that the validation email is sent to the canonicalised address + # as it will later be entered into the database. + # Otherwise the email will be sent to "FOO@bar.com" and stored as + # "foo@bar.com" in database. + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -403,9 +427,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = await self.store.get_user_id_by_threepid( - "email", body["email"] - ) + existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: if self.config.request_token_inhibit_3pid_errors: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 56a451c42f..370742ce59 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -47,7 +47,7 @@ from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed +from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -116,7 +116,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. 
+ # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/v2_alpha/account.py) + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -128,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", body["email"] + "email", email ) if existing_user_id is not None: @@ -552,6 +559,15 @@ class RegisterRestServlet(RestServlet): if login_type in auth_result: medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/v2_alpha/account.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) existing_user_id = await self.store.get_user_id_by_threepid( medium, address diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 3ec1dfb0c2..43c2e0ac23 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -48,3 +48,26 @@ def check_3pid_allowed(hs, medium, address): return True return False + + +def canonicalise_email(address: str) -> str: + """'Canonicalise' email address + Case folding of local part of email address and lowercase domain part + See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265 + + Args: + address: email address to be canonicalised + Returns: + The canonical form of the email address + Raises: + ValueError if the address could not be parsed. + """ + + address = address.strip() + + parts = address.split("@") + if len(parts) != 2: + logger.debug("Couldn't parse email address %s", address) + raise ValueError("Unable to parse email address") + + return parts[0].casefold() + "@" + parts[1].lower() diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 3ab611f618..152a5182fa 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + def test_basic_password_reset_canonicalise_email(self): + """Test basic password reset flow + Request password reset with different spelling + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email_profile = "test@example.com" + email_passwort_reset = "TEST@EXAMPLE.COM" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email_profile, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email_passwort_reset, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + def test_cant_reset_password_without_clicking_link(self): """Test 
that we do actually need to click the link in the email """ @@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_email(self): - """Test adding an email to profile - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) + def test_add_valid_email(self): + self.get_success(self._add_email(self.email, self.email)) - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() + def test_add_valid_email_second_time(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) - self._validate_token(link) + def test_add_valid_email_second_time_canonicalise(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, + def test_add_email_no_at(self): + self.get_success( + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_add_email_two_at(self): + self.get_success( + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + def test_add_email_bad_format(self): + self.get_success( + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + def test_add_email_domain_to_lower(self): + self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + + def test_add_email_domain_with_umlaut(self): + self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + + def test_add_email_address_casefold(self): + self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + + def test_address_trim(self): + self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed @@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] + def _request_token_invalid_email( + self, email, expected_errcode, expected_error, client_secret="foobar", + ): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEqual(400, 
int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(expected_errcode, channel.json_body["errcode"]) + self.assertEqual(expected_error, channel.json_body["error"]) + def _validate_token(self, link): # Remove the host path = link.replace("https://example.com", "") @@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): assert match, "Could not find link in email" return match.group(0) + + def _add_email(self, request_email, expected_email): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(request_email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py new file mode 100644 index 0000000000..5513724d87 --- /dev/null +++ b/tests/util/test_threepids.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
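# (Editorial sketch, not part of the patch: a minimal illustration of how the
# canonicalise_email() helper added above behaves, and of the try/except
# pattern the servlets in this patch wrap around it. request_token_email() is
# a hypothetical name used only for this example.)

from synapse.api.errors import SynapseError
from synapse.util.threepids import canonicalise_email

# Whitespace is stripped, the local part is casefolded and the domain part is
# lowercased, per MSC2265.
assert canonicalise_email("  Strauß@Example.COM ") == "strauss@example.com"


def request_token_email(raw_address: str) -> str:
    # Addresses without exactly one "@" raise ValueError, which the REST
    # servlets in this patch surface to clients as a 400 SynapseError.
    try:
        return canonicalise_email(raw_address)
    except ValueError as e:
        raise SynapseError(400, str(e))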
+ +from synapse.util.threepids import canonicalise_email + +from tests.unittest import HomeserverTestCase + + +class CanonicaliseEmailTests(HomeserverTestCase): + def test_no_at(self): + with self.assertRaises(ValueError): + canonicalise_email("address-without-at.bar") + + def test_two_at(self): + with self.assertRaises(ValueError): + canonicalise_email("foo@foo@test.bar") + + def test_bad_format(self): + with self.assertRaises(ValueError): + canonicalise_email("user@bad.example.net@good.example.com") + + def test_valid_format(self): + self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") + + def test_domain_to_lower(self): + self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") + + def test_domain_with_umlaut(self): + self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") + + def test_address_casefold(self): + self.assertEqual( + canonicalise_email("Strauß@Example.com"), "strauss@example.com" + ) + + def test_address_trim(self): + self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") -- cgit 1.5.1 From 5cdca53aa07f921029cb8027693095d150c37e32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jul 2020 19:02:19 +0100 Subject: Merge different Resource implementation classes (#7732) --- changelog.d/7732.bugfix | 1 + synapse/federation/transport/server.py | 6 +- synapse/http/additional_resource.py | 19 +- synapse/http/server.py | 365 ++++++++++++++------------ synapse/logging/opentracing.py | 68 +++-- synapse/replication/http/__init__.py | 3 +- synapse/replication/http/_base.py | 11 +- synapse/rest/consent/consent_resource.py | 10 +- synapse/rest/key/v2/remote_key_resource.py | 12 +- synapse/rest/media/v1/config_resource.py | 14 +- synapse/rest/media/v1/download_resource.py | 12 +- synapse/rest/media/v1/preview_url_resource.py | 10 +- synapse/rest/media/v1/thumbnail_resource.py | 10 +- synapse/rest/media/v1/upload_resource.py | 14 +- synapse/rest/oidc/callback_resource.py | 7 +- synapse/rest/saml2/response_resource.py | 4 +- tests/http/test_additional_resource.py | 62 +++++ tests/test_server.py | 12 +- 18 files changed, 322 insertions(+), 318 deletions(-) create mode 100644 changelog.d/7732.bugfix create mode 100644 tests/http/test_additional_resource.py (limited to 'tests') diff --git a/changelog.d/7732.bugfix b/changelog.d/7732.bugfix new file mode 100644 index 0000000000..d5e352e141 --- /dev/null +++ b/changelog.d/7732.bugfix @@ -0,0 +1 @@ +Fix "Tried to close a non-active scope!" error messages when opentracing is enabled. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index af4595498c..bfb7831a02 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -361,11 +361,7 @@ class BaseFederationServlet(object): continue server.register_paths( - method, - (pattern,), - self._wrap(code), - self.__class__.__name__, - trace=False, + method, (pattern,), self._wrap(code), self.__class__.__name__, ) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 096619a8c2..479746c9c5 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -13,13 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET +from synapse.http.server import DirectServeJsonResource -from synapse.http.server import wrap_json_request_handler - -class AdditionalResource(Resource): +class AdditionalResource(DirectServeJsonResource): """Resource wrapper for additional_resources If the user has configured additional_resources, we need to wrap the @@ -41,16 +38,10 @@ class AdditionalResource(Resource): handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): function to be called to handle the request. """ - Resource.__init__(self) + super().__init__() self._handler = handler - # required by the request_handler wrapper - self.clock = hs.get_clock() - - def render(self, request): - self._async_render(request) - return NOT_DONE_YET - - @wrap_json_request_handler def _async_render(self, request): + # Cheekily pass the result straight through, so we don't need to worry + # if its an awaitable or not. return self._handler(request) diff --git a/synapse/http/server.py b/synapse/http/server.py index d192de7923..2b35f86066 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import collections import html import logging @@ -21,7 +22,7 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Awaitable, Callable, TypeVar, Union +from typing import Any, Callable, Dict, Tuple, Union import jinja2 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """ """ -def wrap_json_request_handler(h): - """Wraps a request handler method with exception handling. - - Also does the wrapping with request.processing as per wrap_async_request_handler. - - The handler method must have a signature of "handle_foo(self, request)", - where "request" must be a SynapseRequest. - - The handler must return a deferred or a coroutine. If the deferred succeeds - we assume that a response has been sent. If the deferred fails with a SynapseError we use - it to send a JSON response with the appropriate HTTP reponse code. If the - deferred fails with any other type of error we send a 500 reponse. +def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: + """Sends a JSON error response to clients. """ - async def wrapped_request_handler(self, request): - try: - await h(self, request) - except SynapseError as e: - code = e.code - logger.info("%s SynapseError: %s - %s", request, code, e.msg) - - # Only respond with an error response if we haven't already started - # writing, otherwise lets just kill the connection - if request.startedWriting: - if request.transport: - try: - request.transport.abortConnection() - except Exception: - # abortConnection throws if the connection is already closed - pass - else: - respond_with_json( - request, - code, - e.error_dict(), - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) - - except Exception: - # failure.Failure() fishes the original Failure out - # of our stack, and thus gives us a sensible stack - # trace. 
- f = failure.Failure() - logger.error( - "Failed handle request via %r: %r", - request.request_metrics.name, - request, - exc_info=(f.type, f.value, f.getTracebackObject()), - ) - # Only respond with an error response if we haven't already started - # writing, otherwise lets just kill the connection - if request.startedWriting: - if request.transport: - try: - request.transport.abortConnection() - except Exception: - # abortConnection throws if the connection is already closed - pass - else: - respond_with_json( - request, - 500, - {"error": "Internal server error", "errcode": Codes.UNKNOWN}, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) - - return wrap_async_request_handler(wrapped_request_handler) - - -TV = TypeVar("TV") - - -def wrap_html_request_handler( - h: Callable[[TV, SynapseRequest], Awaitable] -) -> Callable[[TV, SynapseRequest], Awaitable[None]]: - """Wraps a request handler method with exception handling. + if f.check(SynapseError): + error_code = f.value.code + error_dict = f.value.error_dict() - Also does the wrapping with request.processing as per wrap_async_request_handler. - - The handler method must have a signature of "handle_foo(self, request)", - where "request" must be a SynapseRequest. - """ + logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg) + else: + error_code = 500 + error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} - async def wrapped_request_handler(self, request): - try: - await h(self, request) - except Exception: - f = failure.Failure() - return_html_error(f, request, HTML_ERROR_TEMPLATE) + logger.error( + "Failed handle request via %r: %r", + request.request_metrics.name, + request, + exc_info=(f.type, f.value, f.getTracebackObject()), + ) - return wrap_async_request_handler(wrapped_request_handler) + # Only respond with an error response if we haven't already started writing, + # otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, + error_code, + error_dict, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) def return_html_error( @@ -249,7 +194,113 @@ class HttpServer(object): pass -class JsonResource(HttpServer, resource.Resource): +class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): + """Base class for resources that have async handlers. + + Sub classes can either implement `_async_render_` to handle + requests by method, or override `_async_render` to handle all requests. + + Args: + extract_context: Whether to attempt to extract the opentracing + context from the request the servlet is handling. + """ + + def __init__(self, extract_context=False): + super().__init__() + + self._extract_context = extract_context + + def render(self, request): + """ This gets called by twisted every time someone sends us a request. + """ + defer.ensureDeferred(self._async_render_wrapper(request)) + return NOT_DONE_YET + + @wrap_async_request_handler + async def _async_render_wrapper(self, request): + """This is a wrapper that delegates to `_async_render` and handles + exceptions, return values, metrics, etc. 
+ """ + try: + request.request_metrics.name = self.__class__.__name__ + + with trace_servlet(request, self._extract_context): + callback_return = await self._async_render(request) + + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) + except Exception: + # failure.Failure() fishes the original Failure out + # of our stack, and thus gives us a sensible stack + # trace. + f = failure.Failure() + self._send_error_response(f, request) + + async def _async_render(self, request): + """Delegates to `_async_render_` methods, or returns a 400 if + no appropriate method exists. Can be overriden in sub classes for + different routing. + """ + + method_handler = getattr( + self, "_async_render_%s" % (request.method.decode("ascii"),), None + ) + if method_handler: + raw_callback_return = method_handler(request) + + # Is it synchronous? We'll allow this for now. + if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await raw_callback_return + else: + callback_return = raw_callback_return + + return callback_return + + _unrecognised_request_handler(request) + + @abc.abstractmethod + def _send_response( + self, request: SynapseRequest, code: int, response_object: Any, + ) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def _send_error_response( + self, f: failure.Failure, request: SynapseRequest, + ) -> None: + raise NotImplementedError() + + +class DirectServeJsonResource(_AsyncResource): + """A resource that will call `self._async_on_` on new requests, + formatting responses and errors as JSON. + """ + + def _send_response( + self, request, code, response_object, + ): + """Implements _AsyncResource._send_response + """ + # TODO: Only enable CORS for the requests that need it. + respond_with_json( + request, + code, + response_object, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + canonical_json=self.canonical_json, + ) + + def _send_error_response( + self, f: failure.Failure, request: SynapseRequest, + ) -> None: + """Implements _AsyncResource._send_error_response + """ + return_json_error(f, request) + + +class JsonResource(DirectServeJsonResource): """ This implements the HttpServer interface and provides JSON support for Resources. @@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource): "_PathEntry", ["pattern", "callback", "servlet_classname"] ) - def __init__(self, hs, canonical_json=True): - resource.Resource.__init__(self) + def __init__(self, hs, canonical_json=True, extract_context=False): + super().__init__(extract_context) self.canonical_json = canonical_json self.clock = hs.get_clock() self.path_regexs = {} self.hs = hs - def register_paths( - self, method, path_patterns, callback, servlet_classname, trace=True - ): + def register_paths(self, method, path_patterns, callback, servlet_classname): """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate @@ -295,37 +344,42 @@ class JsonResource(HttpServer, resource.Resource): servlet_classname (str): The name of the handler to be used in prometheus and opentracing logs. - - trace (bool): Whether we should start a span to trace the servlet. 
""" method = method.encode("utf-8") # method is bytes on py3 - if trace: - # We don't extract the context from the servlet because we can't - # trust the sender - callback = trace_servlet(servlet_classname)(callback) - for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( self._PathEntry(path_pattern, callback, servlet_classname) ) - def render(self, request): - """ This gets called by twisted every time someone sends us a request. + def _get_handler_for_request( + self, request: SynapseRequest + ) -> Tuple[Callable, str, Dict[str, str]]: + """Finds a callback method to handle the given request. + + Returns: + A tuple of the callback to use, the name of the servlet, and the + key word arguments to pass to the callback """ - defer.ensureDeferred(self._async_render(request)) - return NOT_DONE_YET + request_path = request.path.decode("ascii") + + # Loop through all the registered callbacks to check if the method + # and path regex match + for path_entry in self.path_regexs.get(request.method, []): + m = path_entry.pattern.match(request_path) + if m: + # We found a match! + return path_entry.callback, path_entry.servlet_classname, m.groupdict() + + # Huh. No one wanted to handle that? Fiiiiiine. Send 400. + return _unrecognised_request_handler, "unrecognised_request_handler", {} - @wrap_json_request_handler async def _async_render(self, request): - """ This gets called from render() every time someone sends us a request. - This checks if anyone has registered a callback for that method and - path. - """ callback, servlet_classname, group_dict = self._get_handler_for_request(request) - # Make sure we have a name for this handler in prometheus. + # Make sure we have an appopriate name for this handler in prometheus + # (rather than the default of JsonResource). request.request_metrics.name = servlet_classname # Now trigger the callback. If it returns a response, we send it @@ -338,81 +392,42 @@ class JsonResource(HttpServer, resource.Resource): } ) - callback_return = callback(request, **kwargs) + raw_callback_return = callback(request, **kwargs) # Is it synchronous? We'll allow this for now. - if isinstance(callback_return, (defer.Deferred, types.CoroutineType)): - callback_return = await callback_return + if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await raw_callback_return + else: + callback_return = raw_callback_return - if callback_return is not None: - code, response = callback_return - self._send_response(request, code, response) + return callback_return - def _get_handler_for_request(self, request): - """Finds a callback method to handle the given request - Args: - request (twisted.web.http.Request): +class DirectServeHtmlResource(_AsyncResource): + """A resource that will call `self._async_on_` on new requests, + formatting responses and errors as HTML. + """ - Returns: - Tuple[Callable, str, dict[unicode, unicode]]: callback method, the - label to use for that method in prometheus metrics, and the - dict mapping keys to path components as specified in the - handler's path match regexp. - - The callback will normally be a method registered via - register_paths, so will return (possibly via Deferred) either - None, or a tuple of (http code, response body). 
- """ - request_path = request.path.decode("ascii") - - # Loop through all the registered callbacks to check if the method - # and path regex match - for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request_path) - if m: - # We found a match! - return path_entry.callback, path_entry.servlet_classname, m.groupdict() - - # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - return _unrecognised_request_handler, "unrecognised_request_handler", {} + # The error template to use for this resource + ERROR_TEMPLATE = HTML_ERROR_TEMPLATE def _send_response( - self, request, code, response_json_object, response_code_message=None + self, request: SynapseRequest, code: int, response_object: Any, ): - # TODO: Only enable CORS for the requests that need it. - respond_with_json( - request, - code, - response_json_object, - send_cors=True, - response_code_message=response_code_message, - pretty_print=_request_user_agent_is_curl(request), - canonical_json=self.canonical_json, - ) - - -class DirectServeResource(resource.Resource): - def render(self, request): + """Implements _AsyncResource._send_response """ - Render the request, using an asynchronous render handler if it exists. - """ - async_render_callback_name = "_async_render_" + request.method.decode("ascii") - - # Try and get the async renderer - callback = getattr(self, async_render_callback_name, None) + # We expect to get bytes for us to write + assert isinstance(response_object, bytes) + html_bytes = response_object - # No async renderer for this request method. - if not callback: - return super().render(request) + respond_with_html_bytes(request, 200, html_bytes) - resp = trace_servlet(self.__class__.__name__)(callback)(request) - - # If it's a coroutine, turn it into a Deferred - if isinstance(resp, types.CoroutineType): - defer.ensureDeferred(resp) - - return NOT_DONE_YET + def _send_error_response( + self, f: failure.Failure, request: SynapseRequest, + ) -> None: + """Implements _AsyncResource._send_error_response + """ + return_html_error(f, request, self.ERROR_TEMPLATE) class StaticResource(File): diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 73bef5e5ca..1676771ef0 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -169,7 +169,6 @@ import contextlib import inspect import logging import re -import types from functools import wraps from typing import TYPE_CHECKING, Dict, Optional, Type @@ -182,6 +181,7 @@ from synapse.config import ConfigError if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.http.site import SynapseRequest # Helper class @@ -793,48 +793,42 @@ def tag_args(func): return _tag_args_inner -def trace_servlet(servlet_name, extract_context=False): - """Decorator which traces a serlet. It starts a span with some servlet specific - tags such as the servlet_name and request information +@contextlib.contextmanager +def trace_servlet(request: "SynapseRequest", extract_context: bool = False): + """Returns a context manager which traces a request. It starts a span + with some servlet specific tags such as the request metrics name and + request information. Args: - servlet_name (str): The name to be used for the span's operation_name - extract_context (bool): Whether to attempt to extract the opentracing + request + extract_context: Whether to attempt to extract the opentracing context from the request the servlet is handling. 
- """ - def _trace_servlet_inner_1(func): - if not opentracing: - return func - - @wraps(func) - async def _trace_servlet_inner(request, *args, **kwargs): - request_tags = { - "request_id": request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - } - - if extract_context: - scope = start_active_span_from_request( - request, servlet_name, tags=request_tags - ) - else: - scope = start_active_span(servlet_name, tags=request_tags) - - with scope: - result = func(request, *args, **kwargs) + if opentracing is None: + yield + return - if not isinstance(result, (types.CoroutineType, defer.Deferred)): - # Some servlets aren't async and just return results - # directly, so we handle that here. - return result + request_tags = { + "request_id": request.get_request_id(), + tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, + tags.HTTP_METHOD: request.get_method(), + tags.HTTP_URL: request.get_redacted_uri(), + tags.PEER_HOST_IPV6: request.getClientIP(), + } - return await result + request_name = request.request_metrics.name + if extract_context: + scope = start_active_span_from_request(request, request_name, tags=request_tags) + else: + scope = start_active_span(request_name, tags=request_tags) - return _trace_servlet_inner + with scope: + try: + yield + finally: + # We set the operation name again in case its changed (which happens + # with JsonResource). + scope.span.set_operation_name(request.request_metrics.name) - return _trace_servlet_inner_1 + scope.span.set_tag("request_tag", request.request_metrics.start_context.tag) diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 19b69e0e11..5ef1c6c1dc 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication" class ReplicationRestResource(JsonResource): def __init__(self, hs): - JsonResource.__init__(self, hs, canonical_json=False) + # We enable extracting jaeger contexts here as these are internal APIs. + super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) def register_servlets(self, hs): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 9caf1e80c1..0843d28d4b 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -28,11 +28,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.logging.opentracing import ( - inject_active_span_byte_dict, - trace, - trace_servlet, -) +from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -240,11 +236,8 @@ class ReplicationEndpoint(object): args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) - handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler) - # We don't let register paths trace this servlet using the default tracing - # options because we wish to extract the context explicitly. 
http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, trace=False + method, [pattern], handler, self.__class__.__name__, ) def _cached_handler(self, request, txn_id, **kwargs): diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 0a890c98cb..4386eb4e72 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -26,11 +26,7 @@ from twisted.internet import defer from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError -from synapse.http.server import ( - DirectServeResource, - respond_with_html, - wrap_html_request_handler, -) +from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string from synapse.types import UserID @@ -48,7 +44,7 @@ else: return a == b -class ConsentResource(DirectServeResource): +class ConsentResource(DirectServeHtmlResource): """A twisted Resource to display a privacy policy and gather consent to it When accessed via GET, returns the privacy policy via a template. @@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource): self._hmac_secret = hs.config.form_secret.encode("utf-8") - @wrap_html_request_handler async def _async_render_GET(self, request): """ Args: @@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource): except TemplateNotFound: raise NotFoundError("Unknown policy version") - @wrap_html_request_handler async def _async_render_POST(self, request): """ Args: diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index ab671f7334..e149ac1733 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -20,17 +20,13 @@ from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import ( - DirectServeResource, - respond_with_json_bytes, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes from synapse.http.servlet import parse_integer, parse_json_object_from_request logger = logging.getLogger(__name__) -class RemoteKey(DirectServeResource): +class RemoteKey(DirectServeJsonResource): """HTTP resource for retreiving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource): isLeaf = True def __init__(self, hs): + super().__init__() + self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.config = hs.config - @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: (server,) = request.postpath @@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource): await self.query_keys(request, query, query_remote_on_cache_miss=True) - @wrap_json_request_handler async def _async_render_POST(self, request): content = parse_json_object_from_request(request) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 9f747de263..68dd2a1c8a 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -14,16 +14,10 @@ # limitations under the License. 
# -from twisted.web.server import NOT_DONE_YET +from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.server import ( - DirectServeResource, - respond_with_json, - wrap_json_request_handler, -) - -class MediaConfigResource(DirectServeResource): +class MediaConfigResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs): @@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource): self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - @wrap_json_request_handler async def _async_render_GET(self, request): await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - def render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request): respond_with_json(request, 200, {}, send_cors=True) - return NOT_DONE_YET diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 24d3ae5bbc..d3d8457303 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -15,18 +15,14 @@ import logging import synapse.http.servlet -from synapse.http.server import ( - DirectServeResource, - set_cors_headers, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, set_cors_headers from ._base import parse_media_id, respond_404 logger = logging.getLogger(__name__) -class DownloadResource(DirectServeResource): +class DownloadResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo): @@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource): self.media_repo = media_repo self.server_name = hs.hostname - # this is expected by @wrap_json_request_handler - self.clock = hs.get_clock() - - @wrap_json_request_handler async def _async_render_GET(self, request): set_cors_headers(request) request.setHeader( diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b4645cd608..e52c86c798 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient from synapse.http.server import ( - DirectServeResource, + DirectServeJsonResource, respond_with_json, respond_with_json_bytes, - wrap_json_request_handler, ) from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50 OG_TAG_VALUE_MAXLEN = 1000 -class PreviewUrlResource(DirectServeResource): +class PreviewUrlResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo, media_storage): @@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource): self._start_expire_url_cache_data, 10 * 1000 ) - def render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request): request.setHeader(b"Allow", b"OPTIONS, GET") - return respond_with_json(request, 200, {}, send_cors=True) + respond_with_json(request, 200, {}, send_cors=True) - @wrap_json_request_handler async def _async_render_GET(self, request): # XXX: if get_user_by_req fails, what should we do in an async render? 
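(Editorial aside, not part of the patch: the hunks above replace the `@wrap_json_request_handler` / `@wrap_html_request_handler` decorators and the old `DirectServeResource` with the new `DirectServeJsonResource` and `DirectServeHtmlResource` base classes, whose `_async_render_wrapper` now takes care of error responses, metrics naming and opentracing via `trace_servlet`. The sketch below, using the made-up name `ExampleStatusResource`, only illustrates the resulting pattern; it mirrors the `MediaConfigResource` changes above rather than adding anything new.)

from synapse.http.server import DirectServeJsonResource, respond_with_json


class ExampleStatusResource(DirectServeJsonResource):
    """Hypothetical resource illustrating the post-refactor pattern: subclasses
    implement `_async_render_<METHOD>` directly, and the base class handles
    exceptions, metrics naming and tracing, so the old decorator and the
    `self.clock` attribute it required are no longer needed.
    """

    isLeaf = True

    def __init__(self, hs):
        super().__init__()
        self.auth = hs.get_auth()

    async def _async_render_GET(self, request):
        # Authenticate, then respond; returning None means we have already
        # written the response ourselves.
        await self.auth.get_user_by_req(request)
        respond_with_json(request, 200, {"status": "ok"}, send_cors=True)

    async def _async_render_OPTIONS(self, request):
        respond_with_json(request, 200, {}, send_cors=True)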
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 0b87220234..a83535b97b 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,11 +16,7 @@ import logging -from synapse.http.server import ( - DirectServeResource, - set_cors_headers, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string from ._base import ( @@ -34,7 +30,7 @@ from ._base import ( logger = logging.getLogger(__name__) -class ThumbnailResource(DirectServeResource): +class ThumbnailResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo, media_storage): @@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource): self.media_storage = media_storage self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - self.clock = hs.get_clock() - @wrap_json_request_handler async def _async_render_GET(self, request): set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 83d005812d..3ebf7a68e6 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -15,20 +15,14 @@ import logging -from twisted.web.server import NOT_DONE_YET - from synapse.api.errors import Codes, SynapseError -from synapse.http.server import ( - DirectServeResource, - respond_with_json, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string logger = logging.getLogger(__name__) -class UploadResource(DirectServeResource): +class UploadResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo): @@ -43,11 +37,9 @@ class UploadResource(DirectServeResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - def render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request): respond_with_json(request, 200, {}, send_cors=True) - return NOT_DONE_YET - @wrap_json_request_handler async def _async_render_POST(self, request): requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py index c03194f001..f7a0bc4bdb 100644 --- a/synapse/rest/oidc/callback_resource.py +++ b/synapse/rest/oidc/callback_resource.py @@ -14,18 +14,17 @@ # limitations under the License. 
import logging -from synapse.http.server import DirectServeResource, wrap_html_request_handler +from synapse.http.server import DirectServeHtmlResource logger = logging.getLogger(__name__) -class OIDCCallbackResource(DirectServeResource): +class OIDCCallbackResource(DirectServeHtmlResource): isLeaf = 1 def __init__(self, hs): super().__init__() self._oidc_handler = hs.get_oidc_handler() - @wrap_html_request_handler async def _async_render_GET(self, request): - return await self._oidc_handler.handle_oidc_callback(request) + await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 75e58043b4..c10188a5d7 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -16,10 +16,10 @@ from twisted.python import failure from synapse.api.errors import SynapseError -from synapse.http.server import DirectServeResource, return_html_error +from synapse.http.server import DirectServeHtmlResource, return_html_error -class SAML2ResponseResource(DirectServeResource): +class SAML2ResponseResource(DirectServeHtmlResource): """A Twisted web resource which handles the SAML response""" isLeaf = 1 diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py new file mode 100644 index 0000000000..62d36c2906 --- /dev/null +++ b/tests/http/test_additional_resource.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.http.additional_resource import AdditionalResource +from synapse.http.server import respond_with_json + +from tests.unittest import HomeserverTestCase + + +class _AsyncTestCustomEndpoint: + def __init__(self, config, module_api): + pass + + async def handle_request(self, request): + respond_with_json(request, 200, {"some_key": "some_value_async"}) + + +class _SyncTestCustomEndpoint: + def __init__(self, config, module_api): + pass + + async def handle_request(self, request): + respond_with_json(request, 200, {"some_key": "some_value_sync"}) + + +class AdditionalResourceTests(HomeserverTestCase): + """Very basic tests that `AdditionalResource` works correctly with sync + and async handlers. 
+ """ + + def test_async(self): + handler = _AsyncTestCustomEndpoint({}, None).handle_request + self.resource = AdditionalResource(self.hs, handler) + + request, channel = self.make_request("GET", "/") + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) + + def test_sync(self): + handler = _SyncTestCustomEndpoint({}, None).handle_request + self.resource = AdditionalResource(self.hs, handler) + + request, channel = self.make_request("GET", "/") + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/test_server.py b/tests/test_server.py index 3f6f468e5b..030f58cbdc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -24,12 +24,7 @@ from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def -from synapse.http.server import ( - DirectServeResource, - JsonResource, - OptionsResource, - wrap_html_request_handler, -) +from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource from synapse.http.site import SynapseSite, logger from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock @@ -256,12 +251,11 @@ class OptionsResourceTests(unittest.TestCase): class WrapHtmlRequestHandlerTests(unittest.TestCase): - class TestResource(DirectServeResource): + class TestResource(DirectServeHtmlResource): callback = None - @wrap_html_request_handler async def _async_render_GET(self, request): - return await self.callback(request) + await self.callback(request) def setUp(self): self.reactor = ThreadedMemoryReactorClock() -- cgit 1.5.1 From 62b1ce85398f52e7d6137e77083294d0c90af459 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sun, 5 Jul 2020 16:32:02 +0100 Subject: isort 5 compatibility (#7786) The CI appears to use the latest version of isort, which is a problem when isort gets a major version bump. Rather than try to pin the version, I've done the necessary to make isort5 happy with synapse. --- changelog.d/7786.misc | 1 + scripts-dev/check_signature.py | 2 +- scripts-dev/lint.sh | 2 +- setup.cfg | 1 - synapse/api/auth.py | 3 +-- synapse/config/__main__.py | 1 + synapse/config/emailconfig.py | 3 +-- synapse/handlers/auth.py | 3 +-- synapse/handlers/cas_handler.py | 3 +-- synapse/logging/opentracing.py | 4 ++-- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 4 ++-- synapse/replication/tcp/streams/events.py | 2 -- synapse/rest/media/v1/thumbnailer.py | 3 +-- synapse/secrets.py | 3 +-- synapse/storage/data_stores/main/events.py | 3 +-- synapse/storage/data_stores/main/ui_auth.py | 2 +- synapse/storage/types.py | 2 -- synapse/types.py | 2 +- tests/handlers/test_e2e_keys.py | 4 +--- tests/rest/media/v1/test_media_storage.py | 4 +--- tests/test_utils/event_injection.py | 2 -- tox.ini | 4 ++-- 23 files changed, 22 insertions(+), 38 deletions(-) create mode 100644 changelog.d/7786.misc (limited to 'tests') diff --git a/changelog.d/7786.misc b/changelog.d/7786.misc new file mode 100644 index 0000000000..27af2681dc --- /dev/null +++ b/changelog.d/7786.misc @@ -0,0 +1 @@ +Update linting scripts and codebase to be compatible with `isort` v5. 
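(Editorial aside, not part of the patch: the hunks that follow adapt the codebase to isort 5, dropping the `-y -rc` flags from `scripts-dev/lint.sh` and the `not_skip` setting from `setup.cfg`, which are no longer supported in the new major version, and pinning `isort==5.0.3` in tox. The snippet below is only an illustrative sketch, adapted from imports that appear in these hunks, of the grouping the retained `sections`/`known_first_party` settings enforce: stdlib, then third party, then Twisted, then `synapse`, then `tests`.)

# Hypothetical module used only to illustrate the import ordering isort 5
# produces with the configuration kept in setup.cfg.
import logging
from typing import Dict, Optional, Tuple

from signedjson import key as key, sign as sign

from twisted.internet import defer

from synapse.logging import opentracing as opentracing
from synapse.util import stringutils as stringutils

from tests.unittest import HomeserverTestCase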
diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index ecda103cf7..6755bc5282 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -2,9 +2,9 @@ import argparse import json import logging import sys -import urllib2 import dns.resolver +import urllib2 from signedjson.key import decode_verify_key_bytes, write_signing_keys from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 6f1ba22931..66b0568858 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -15,7 +15,7 @@ else fi echo "Linting these locations: $files" -isort -y -rc $files +isort $files python3 -m black $files ./scripts-dev/config-lint.sh flake8 $files diff --git a/setup.cfg b/setup.cfg index f2bca272e1..a32278ea8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,6 @@ ignore=W503,W504,E203,E731,E501 [isort] line_length = 88 -not_skip = __init__.py sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY known_first_party = synapse diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 06ba6604f3..cb22508f4d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Optional @@ -22,7 +21,6 @@ from netaddr import IPAddress from twisted.internet import defer from twisted.web.server import Request -import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth from synapse.api.auth_blocking import AuthBlocking @@ -35,6 +33,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index fca35b008c..65043d5b5b 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -16,6 +16,7 @@ from synapse.config._base import ConfigError if __name__ == "__main__": import sys + from synapse.config.homeserver import HomeServerConfig action = sys.argv[1] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index ca61214454..df08bcd1bc 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import print_function # This file can't be called email.py because if it is, we cannot: @@ -145,8 +144,8 @@ class EmailConfig(Config): or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL ): # make sure we can import the required deps - import jinja2 import bleach + import jinja2 # prevent unused warnings jinja2 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d713a06bf9..a162392e4c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import logging import time import unicodedata @@ -24,7 +23,6 @@ import attr import bcrypt # type: ignore[import] import pymacaroons -import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( AuthError, @@ -45,6 +43,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID +from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email from ._base import BaseHandler diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 76f213723a..d79ffefdb5 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import urllib -import xml.etree.ElementTree as ET from typing import Dict, Optional, Tuple +from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 1676771ef0..c6c0e623c1 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -164,7 +164,6 @@ Gotchas than one caller? Will all of those calling functions have be in a context with an active span? """ - import contextlib import inspect import logging @@ -180,8 +179,8 @@ from twisted.internet import defer from synapse.config import ConfigError if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.http.site import SynapseRequest + from synapse.server import HomeServer # Helper class @@ -227,6 +226,7 @@ except ImportError: tags = _DummyTagNames try: from jaeger_client import Config as JaegerConfig + from synapse.logging.scopecontextmanager import LogContextScopeManager except ImportError: JaegerConfig = None # type: ignore diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index df29732f51..4985e40b1f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e6a2e2598b..55b3b79008 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar @@ -149,10 +148,11 @@ class ReplicationCommandHandler: using TCP. 
""" if hs.config.redis.redis_enabled: + import txredisapi + from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) - import txredisapi logger.info( "Connecting to redis (host=%r port=%r)", diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f370390331..bdddb62ad6 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import heapq from collections import Iterable from typing import List, Tuple, Type @@ -22,7 +21,6 @@ import attr from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance - """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index c234ea7421..7126997134 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from io import BytesIO -import PIL.Image as Image +from PIL import Image as Image logger = logging.getLogger(__name__) diff --git a/synapse/secrets.py b/synapse/secrets.py index 0b327a0f82..5f43f81eb0 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -19,7 +19,6 @@ Injectable secrets module for Synapse. See https://docs.python.org/3/library/secrets.html#module-secrets for the API used in Python 3.6, and the API emulated in Python 2.7. """ - import sys # secrets is available since python 3.6 @@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6): else: - import os import binascii + import os class Secrets(object): def token_bytes(self, nbytes=32): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index cfd24d2f06..b7bf3fbd9d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import itertools import logging from collections import OrderedDict, namedtuple @@ -48,8 +47,8 @@ from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: - from synapse.storage.data_stores.main import DataStore from synapse.server import HomeServer + from synapse.storage.data_stores.main import DataStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index ec2f38c373..4c044b1a15 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union import attr -import synapse.util.stringutils as stringutils from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.types import JsonDict +from synapse.util import stringutils as stringutils @attr.s diff --git a/synapse/storage/types.py b/synapse/storage/types.py index daff81c5ee..2d2b560e74 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,12 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Iterable, Iterator, List, Tuple from typing_extensions import Protocol - """ Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ diff --git a/synapse/types.py b/synapse/types.py index acf60baddc..238b938064 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container + from typing import Container, Iterable, Sized T_co = TypeVar("T_co", covariant=True) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6c1dc72bd1..1acf287ca4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,11 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import mock -import signedjson.key as key -import signedjson.sign as sign +from signedjson import key as key, sign as sign from twisted.internet import defer diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2ed9312d56..66fa5978b2 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import os import shutil import tempfile @@ -25,8 +23,8 @@ from urllib import parse from mock import Mock import attr -import PIL.Image as Image from parameterized import parameterized_class +from PIL import Image as Image from twisted.internet.defer import Deferred diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 431e9f8e5e..43297b530c 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional, Tuple import synapse.server @@ -25,7 +24,6 @@ from synapse.types import Collection from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. """ diff --git a/tox.ini b/tox.ini index ab6557f15e..1c042cb227 100644 --- a/tox.ini +++ b/tox.ini @@ -131,8 +131,8 @@ commands = [testenv:check_isort] skip_install = True -deps = isort -commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts" +deps = isort==5.0.3 +commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts" [testenv:check-newsfragment] skip_install = True -- cgit 1.5.1 From 57feeab364325374b14ff67ac97c288983cc5cde Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 6 Jul 2020 11:43:41 +0100 Subject: Don't ignore `set_tweak` actions with no explicit `value`. (#7766) * Fix spec compliance; tweaks without values are valid (default to True, which is only concretely specified for `highlight`, but it seems only reasonable to generalise) * Changelog for 7766. * Add documentation to `tweaks_for_actions` May as well tidy up when I'm here. * Add a test for `tweaks_for_actions` --- changelog.d/7766.bugfix | 1 + synapse/push/push_rule_evaluator.py | 31 +++++++++++++++++++++++++++---- tests/push/test_push_rule_evaluator.py | 17 +++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7766.bugfix (limited to 'tests') diff --git a/changelog.d/7766.bugfix b/changelog.d/7766.bugfix new file mode 100644 index 0000000000..ec5ecd8055 --- /dev/null +++ b/changelog.d/7766.bugfix @@ -0,0 +1 @@ +Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 8e0d3a416d..2d79ada189 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,7 +16,7 @@ import logging import re -from typing import Pattern +from typing import Any, Dict, List, Pattern, Union from synapse.events import EventBase from synapse.types import UserID @@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number): return False -def tweaks_for_actions(actions): +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ tweaks = {} for a in actions: if not isinstance(a, dict): continue - if "set_tweak" in a and "value" in a: - tweaks[a["set_tweak"]] = a["value"] + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. 
+ tweaks[a["set_tweak"]] = a.get("value", True) return tweaks diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index af35d23aea..1f4b5ca2ac 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -15,6 +15,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent +from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent from tests import unittest @@ -84,3 +85,19 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): for body in (1, True, {"foo": "bar"}): evaluator = self._get_evaluator({"body": body}) self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_tweaks_for_actions(self): + """ + This tests the behaviour of tweaks_for_actions. + """ + + actions = [ + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + "notify", + ] + + self.assertEqual( + push_rule_evaluator.tweaks_for_actions(actions), + {"sound": "default", "highlight": True}, + ) -- cgit 1.5.1 From 6d687ebba11c701698df1f3da6fccec3b486c25a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Jul 2020 07:40:35 -0400 Subject: Convert the appservice handler to async/await. (#7775) --- changelog.d/7775.misc | 1 + synapse/appservice/api.py | 1 - synapse/handlers/appservice.py | 74 +++++++++++++++++---------------------- tests/handlers/test_appservice.py | 68 ++++++++++++++++++----------------- 4 files changed, 68 insertions(+), 76 deletions(-) create mode 100644 changelog.d/7775.misc (limited to 'tests') diff --git a/changelog.d/7775.misc b/changelog.d/7775.misc new file mode 100644 index 0000000000..af6fdb782f --- /dev/null +++ b/changelog.d/7775.misc @@ -0,0 +1 @@ +Convert the appserver handler to async/await. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index da9a5e86d4..f92bfb420b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient): if service.url is None: return False uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) - response = None try: response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 904c96eeec..92d4c6e16c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -48,8 +48,7 @@ class ApplicationServicesHandler(object): self.current_max = 0 self.is_processing = False - @defer.inlineCallbacks - def notify_interested_services(self, current_id): + async def notify_interested_services(self, current_id): """Notifies (pushes) all application services interested in this event. 
Pushing is done asynchronously, so this method won't block for any @@ -74,7 +73,7 @@ class ApplicationServicesHandler(object): ( upper_bound, events, - ) = yield self.store.get_new_events_for_appservice( + ) = await self.store.get_new_events_for_appservice( self.current_max, limit ) @@ -85,10 +84,9 @@ class ApplicationServicesHandler(object): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - @defer.inlineCallbacks - def handle_event(event): + async def handle_event(event): # Gather interested services - services = yield self._get_services_for_event(event) + services = await self._get_services_for_event(event) if len(services) == 0: return # no services need notifying @@ -96,9 +94,9 @@ class ApplicationServicesHandler(object): # query API for all services which match that user regex. # This needs to block as these user queries need to be # made BEFORE pushing the event. - yield self._check_user_exists(event.sender) + await self._check_user_exists(event.sender) if event.type == EventTypes.Member: - yield self._check_user_exists(event.state_key) + await self._check_user_exists(event.state_key) if not self.started_scheduler: @@ -115,17 +113,16 @@ class ApplicationServicesHandler(object): self.scheduler.submit_event_for_as(service, event) now = self.clock.time_msec() - ts = yield self.store.get_received_ts(event.event_id) + ts = await self.store.get_received_ts(event.event_id) synapse.metrics.event_processing_lag_by_event.labels( "appservice_sender" ).observe((now - ts) / 1000) - @defer.inlineCallbacks - def handle_room_events(events): + async def handle_room_events(events): for event in events: - yield handle_event(event) + await handle_event(event) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(handle_room_events, evs) @@ -135,10 +132,10 @@ class ApplicationServicesHandler(object): ) ) - yield self.store.set_appservice_last_pos(upper_bound) + await self.store.set_appservice_last_pos(upper_bound) now = self.clock.time_msec() - ts = yield self.store.get_received_ts(events[-1].event_id) + ts = await self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_positions.labels( "appservice_sender" @@ -161,8 +158,7 @@ class ApplicationServicesHandler(object): finally: self.is_processing = False - @defer.inlineCallbacks - def query_user_exists(self, user_id): + async def query_user_exists(self, user_id): """Check if any application service knows this user_id exists. Args: @@ -170,15 +166,14 @@ class ApplicationServicesHandler(object): Returns: True if this user exists on at least one application service. """ - user_query_services = yield self._get_services_for_user(user_id=user_id) + user_query_services = self._get_services_for_user(user_id=user_id) for user_service in user_query_services: - is_known_user = yield self.appservice_api.query_user(user_service, user_id) + is_known_user = await self.appservice_api.query_user(user_service, user_id) if is_known_user: return True return False - @defer.inlineCallbacks - def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists(self, room_alias): """Check if an application service knows this room alias exists. 
Args: @@ -193,19 +188,18 @@ class ApplicationServicesHandler(object): s for s in services if (s.is_interested_in_alias(room_alias_str)) ] for alias_service in alias_query_services: - is_known_alias = yield self.appservice_api.query_alias( + is_known_alias = await self.appservice_api.query_alias( alias_service, room_alias_str ) if is_known_alias: # the alias exists now so don't query more ASes. - result = yield self.store.get_association_from_room_alias(room_alias) + result = await self.store.get_association_from_room_alias(room_alias) return result - @defer.inlineCallbacks - def query_3pe(self, kind, protocol, fields): - services = yield self._get_services_for_3pn(protocol) + async def query_3pe(self, kind, protocol, fields): + services = self._get_services_for_3pn(protocol) - results = yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.DeferredList( [ run_in_background( @@ -224,8 +218,7 @@ class ApplicationServicesHandler(object): return ret - @defer.inlineCallbacks - def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols(self, only_protocol=None): services = self.store.get_app_services() protocols = {} @@ -238,7 +231,7 @@ class ApplicationServicesHandler(object): if p not in protocols: protocols[p] = [] - info = yield self.appservice_api.get_3pe_protocol(s, p) + info = await self.appservice_api.get_3pe_protocol(s, p) if info is not None: protocols[p].append(info) @@ -263,8 +256,7 @@ class ApplicationServicesHandler(object): return protocols - @defer.inlineCallbacks - def _get_services_for_event(self, event): + async def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. Args: @@ -280,7 +272,7 @@ class ApplicationServicesHandler(object): # inside of a list comprehension anymore. interested_list = [] for s in services: - if (yield s.is_interested(event, self.store)): + if await s.is_interested(event, self.store): interested_list.append(s) return interested_list @@ -288,21 +280,20 @@ class ApplicationServicesHandler(object): def _get_services_for_user(self, user_id): services = self.store.get_app_services() interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return defer.succeed(interested_list) + return interested_list def _get_services_for_3pn(self, protocol): services = self.store.get_app_services() interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return defer.succeed(interested_list) + return interested_list - @defer.inlineCallbacks - def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id): if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. 
return False - user_info = yield self.store.get_user_by_id(user_id) + user_info = await self.store.get_user_by_id(user_id) if user_info: return False @@ -311,10 +302,9 @@ class ApplicationServicesHandler(object): service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - @defer.inlineCallbacks - def _check_user_exists(self, user_id): - unknown_user = yield self._is_unknown_user(user_id) + async def _check_user_exists(self, user_id): + unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = yield self.query_user_exists(user_id) + exists = await self.query_user_exists(user_id) return exists return True diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ba7148ec01..ebabe9a7d6 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -32,10 +32,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore = Mock(return_value=self.mock_store) - self.mock_store.get_received_ts.return_value = 0 - hs.get_application_service_api = Mock(return_value=self.mock_as_api) - hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) + hs.get_datastore.return_value = self.mock_store + self.mock_store.get_received_ts.return_value = defer.succeed(0) + self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None) + hs.get_application_service_api.return_value = self.mock_as_api + hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) @@ -48,18 +49,18 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice(is_interested=False), ] - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value=[]) + self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_store.get_app_services.return_value = services + self.mock_store.get_user_by_id.return_value = defer.succeed([]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ - (0, [event]), - (0, []), + defer.succeed((0, [event])), + defer.succeed((0, [])), ] - self.mock_as_api.push = Mock() - yield self.handler.notify_interested_services(0) + yield defer.ensureDeferred(self.handler.notify_interested_services(0)) self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) @@ -68,36 +69,34 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] - services[0].is_interested_in_user = Mock(return_value=True) - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value=None) + services[0].is_interested_in_user.return_value = True + self.mock_store.get_app_services.return_value = services + self.mock_store.get_user_by_id.return_value = defer.succeed(None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.push = Mock() - self.mock_as_api.query_user = Mock() + self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - (0, [event]), - (0, []), + defer.succeed((0, [event])), + defer.succeed((0, [])), ] - yield 
self.handler.notify_interested_services(0) + yield defer.ensureDeferred(self.handler.notify_interested_services(0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] - services[0].is_interested_in_user = Mock(return_value=True) - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value={"name": user_id}) + services[0].is_interested_in_user.return_value = True + self.mock_store.get_app_services.return_value = services + self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.push = Mock() - self.mock_as_api.query_user = Mock() + self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - (0, [event]), - (0, []), + defer.succeed((0, [event])), + defer.succeed((0, [])), ] - yield self.handler.notify_interested_services(0) + yield defer.ensureDeferred(self.handler.notify_interested_services(0)) self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", @@ -107,7 +106,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_room_alias_exists(self): room_alias_str = "#foo:bar" room_alias = Mock() - room_alias.to_string = Mock(return_value=room_alias_str) + room_alias.to_string.return_value = room_alias_str room_id = "!alpha:bet" servers = ["aperture"] @@ -118,12 +117,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_interested_in_alias=False), ] - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_association_from_room_alias = Mock( - return_value=Mock(room_id=room_id, servers=servers) + self.mock_as_api.query_alias.return_value = defer.succeed(True) + self.mock_store.get_app_services.return_value = services + self.mock_store.get_association_from_room_alias.return_value = defer.succeed( + Mock(room_id=room_id, servers=servers) ) - result = yield self.handler.query_room_alias_exists(room_alias) + result = yield defer.ensureDeferred( + self.handler.query_room_alias_exists(room_alias) + ) self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str @@ -133,14 +135,14 @@ class AppServiceHandlerTestCase(unittest.TestCase): def _mkservice(self, is_interested): service = Mock() - service.is_interested = Mock(return_value=is_interested) + service.is_interested.return_value = defer.succeed(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" return service def _mkservice_alias(self, is_interested_in_alias): service = Mock() - service.is_interested_in_alias = Mock(return_value=is_interested_in_alias) + service.is_interested_in_alias.return_value = is_interested_in_alias service.token = "mock_service_token" service.url = "mock_service_url" return service -- cgit 1.5.1 From 76dbd7b8d62beb10ee9304000fad62f65ba23876 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Jul 2020 14:20:40 +0100 Subject: Stop populating unused table `local_invites`. (#7793) This table is no longer used, so we may as well stop populating it. Removing it would prevent people rolling back to older releases of Synapse, so that can happen in a future release. 
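An illustrative aside rather than part of the original patch: the hunks below drop the `local_invites` insert/update bookkeeping in favour of a single upsert into `local_current_membership`, keyed on the room and user. The following self-contained sketch uses plain sqlite3 to show that upsert semantics — one row per (room_id, user_id), overwritten as the membership changes. The column names are taken from the hunk itself; the exact schema, primary key and SQL used here are assumptions for illustration only, not Synapse's real storage layer.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_current_membership ("
    " room_id TEXT, user_id TEXT, event_id TEXT, membership TEXT,"
    " PRIMARY KEY (room_id, user_id))"
)

def upsert_membership(room_id, user_id, event_id, membership):
    # Equivalent in spirit to the simple_upsert_txn(keyvalues=..., values=...)
    # call in the hunk below: insert the row, or overwrite it if one already
    # exists for this (room_id, user_id).
    conn.execute(
        "INSERT INTO local_current_membership (room_id, user_id, event_id, membership)"
        " VALUES (?, ?, ?, ?)"
        " ON CONFLICT (room_id, user_id) DO UPDATE SET"
        " event_id = excluded.event_id, membership = excluded.membership",
        (room_id, user_id, event_id, membership),
    )

upsert_membership("!foo:example.com", "@user:example.com", "$invite_event", "invite")
upsert_membership("!foo:example.com", "@user:example.com", "$join_event", "join")
print(conn.execute("SELECT * FROM local_current_membership").fetchall())
# -> [('!foo:example.com', '@user:example.com', '$join_event', 'join')]

Because the current membership is simply overwritten per (room_id, user_id), there is no per-invite state such as the `locally_rejected` / `replaced_by` columns the removed `local_invites` code had to keep in sync, which is what makes those writes unnecessary.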
--- changelog.d/7793.misc | 1 + synapse/storage/data_stores/main/events.py | 98 ++++++----------------- synapse/storage/data_stores/main/events_worker.py | 5 +- synapse/storage/data_stores/main/purge_events.py | 1 - tests/rest/admin/test_room.py | 1 - 5 files changed, 25 insertions(+), 81 deletions(-) create mode 100644 changelog.d/7793.misc (limited to 'tests') diff --git a/changelog.d/7793.misc b/changelog.d/7793.misc new file mode 100644 index 0000000000..2b6cfbe274 --- /dev/null +++ b/changelog.d/7793.misc @@ -0,0 +1 @@ +Stop populating unused table `local_invites`. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index b7bf3fbd9d..a18317366c 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -27,12 +27,7 @@ from prometheus_client import Counter from twisted.internet import defer import synapse.metrics -from synapse.api.constants import ( - EventContentFields, - EventTypes, - Membership, - RelationTypes, -) +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 @@ -819,7 +814,6 @@ class PersistEventsStore: "event_reference_hashes", "event_search", "event_to_state_groups", - "local_invites", "state_events", "rejections", "redactions", @@ -1196,65 +1190,27 @@ class PersistEventsStore: (event.state_key,), ) - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self.db.simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - # We also update the `local_current_membership` table with - # latest invite info. This will usually get updated by the - # `current_state_events` handling, unless its an outlier. - if event.internal_metadata.is_outlier(): - # This should only happen for out of band memberships, so - # we add a paranoia check. - assert event.internal_metadata.is_out_of_band_membership() - - self.db.simple_upsert_txn( - txn, - table="local_current_membership", - keyvalues={ - "room_id": event.room_id, - "user_id": event.state_key, - }, - values={ - "event_id": event.event_id, - "membership": event.membership, - }, - ) + # We update the local_current_membership table only if the event is + # "current", i.e., its something that has just happened. 
+ # + # This will usually get updated by the `current_state_events` handling, + # unless its an outlier, and an outlier is only "current" if it's an "out of + # band membership", like a remote invite or a rejection of a remote invite. + if ( + self.is_mine_id(event.state_key) + and not backfilled + and event.internal_metadata.is_outlier() + and event.internal_metadata.is_out_of_band_membership() + ): + self.db.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": event.room_id, "user_id": event.state_key}, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events @@ -1591,16 +1547,8 @@ class PersistEventsStore: create a leave event for it. """ - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - # We also clear this entry from `local_current_membership`. + # Clear this entry from `local_current_membership`. # Ideally we'd point to a leave event, but we don't have one, so # nevermind. self.db.simple_delete_txn( diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 47a3e63589..01cad7d4fa 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -82,10 +82,7 @@ class EventsWorkerStore(SQLBaseStore): # We are the process in charge of generating stream ids for events, # so instantiate ID generators based on the database self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], + db_conn, "events", "stream_ordering", ) self._backfill_id_gen = StreamIdGenerator( db_conn, diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py index a93e1ef198..6546569139 100644 --- a/synapse/storage/data_stores/main/purge_events.py +++ b/synapse/storage/data_stores/main/purge_events.py @@ -361,7 +361,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "event_push_summary", "pusher_throttle", "group_summary_rooms", - "local_invites", "room_account_data", "room_tags", "local_current_membership", diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 54cd24bf64..ae6d05a043 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -213,7 +213,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "event_push_summary", "pusher_throttle", "group_summary_rooms", - "local_invites", "room_account_data", "room_tags", # "state_groups", # Current impl leaves orphaned state groups around. -- cgit 1.5.1 From 67593b17287ae5e412c3d30db64d006d3b55349b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 8 Jul 2020 17:51:56 +0100 Subject: Add `HomeServer.signing_key` property (#7805) ... 
instead of duplicating `config.signing_key[0]` everywhere --- changelog.d/7805.misc | 1 + synapse/events/builder.py | 2 +- synapse/federation/federation_client.py | 2 +- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/server.py | 2 ++ tests/storage/test_base.py | 7 ++----- 10 files changed, 12 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7805.misc (limited to 'tests') diff --git a/changelog.d/7805.misc b/changelog.d/7805.misc new file mode 100644 index 0000000000..cbae08774a --- /dev/null +++ b/changelog.d/7805.misc @@ -0,0 +1 @@ +Add `signing_key` property to `HomeServer` to save code duplication. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a0c4a40c27..92aadfe7ef 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -162,7 +162,7 @@ class EventBuilderFactory(object): def __init__(self, hs): self.clock = hs.get_clock() self.hostname = hs.hostname - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.store = hs.get_datastore() self.state = hs.get_state_handler() diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 687cd841ac..07d41ec03f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -87,7 +87,7 @@ class FederationClient(FederationBase): self.transport_layer = hs.get_federation_transport_client() self.hostname = hs.hostname - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self._get_pdu_cache = ExpiringCache( cache_name="get_pdu_cache", diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 27b0c02655..dab13c243f 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -70,7 +70,7 @@ class GroupAttestationSigning(object): self.keyring = hs.get_keyring() self.clock = hs.get_clock() self.server_name = hs.hostname - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key @defer.inlineCallbacks def verify_attestation(self, attestation, group_id, user_id, server_name=None): diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8db8ab1b7b..8cb922ddc7 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -41,7 +41,7 @@ class GroupsServerWorkerHandler(object): self.clock = hs.get_clock() self.keyring = hs.get_keyring() self.is_mine_id = hs.is_mine_id - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.server_name = hs.hostname self.attestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b5aaa244dd..ca7da42a3f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1567,7 +1567,7 @@ class FederationHandler(BaseHandler): room_version, event.get_pdu_json(), self.hs.hostname, - self.hs.config.signing_key[0], + self.hs.signing_key, ) ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 7cb106e365..ecdb12a7bf 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -70,7 +70,7 @@ class GroupsLocalWorkerHandler(object): self.clock = hs.get_clock() self.keyring = hs.get_keyring() self.is_mine_id = hs.is_mine_id - 
self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.server_name = hs.hostname self.notifier = hs.get_notifier() self.attestations = hs.get_groups_attestation_signing() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 58aed5fd96..148eeb19dc 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -176,7 +176,7 @@ class MatrixFederationHttpClient(object): def __init__(self, hs, tls_client_options_factory): self.hs = hs - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.server_name = hs.hostname real_reactor = hs.get_reactor() diff --git a/synapse/server.py b/synapse/server.py index fe94836a2c..6acce2e23f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -232,6 +232,8 @@ class HomeServer(object): self._reactor = reactor self.hostname = hostname + # the key we use to sign events and requests + self.signing_key = config.key.signing_key[0] self.config = config self._building = {} self._listening_services = [] diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 278961c331..b589506c60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -25,7 +25,7 @@ from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest -from tests.utils import TestHomeServer +from tests.utils import TestHomeServer, default_config class SQLBaseStoreTestCase(unittest.TestCase): @@ -49,10 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.db_pool.runWithConnection = runWithConnection - config = Mock() - config._disable_native_upserts = True - config.caches = Mock() - config.caches.event_cache_size = 1 + config = default_config(name="test", parse=True) hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} -- cgit 1.5.1 From 38e1fac8861f12b707609da06008695a05aaf21c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Jul 2020 09:52:58 -0400 Subject: Fix some spelling mistakes / typos. 
(#7811) --- changelog.d/7811.misc | 1 + synapse/api/auth.py | 2 +- synapse/config/emailconfig.py | 2 +- synapse/federation/federation_client.py | 4 ++-- synapse/federation/federation_server.py | 6 +++--- synapse/federation/send_queue.py | 2 +- synapse/federation/sender/per_destination_queue.py | 4 ++-- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 4 ++-- synapse/notifier.py | 2 +- synapse/replication/http/_base.py | 4 ++-- synapse/replication/tcp/__init__.py | 2 +- synapse/replication/tcp/commands.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/redis.py | 2 +- synapse/replication/tcp/streams/events.py | 2 +- synapse/streams/config.py | 4 ++-- synapse/streams/events.py | 2 +- synapse/util/__init__.py | 2 +- synapse/util/async_helpers.py | 2 +- synapse/util/caches/descriptors.py | 2 +- synapse/util/distributor.py | 2 +- synapse/util/patch_inline_callbacks.py | 2 +- synapse/util/retryutils.py | 4 ++-- synapse/visibility.py | 4 ++-- tests/crypto/test_keyring.py | 2 +- tests/rest/client/test_retention.py | 2 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v2_alpha/test_relations.py | 2 +- tests/test_mau.py | 2 +- tests/util/test_logcontext.py | 4 ++-- 31 files changed, 41 insertions(+), 40 deletions(-) create mode 100644 changelog.d/7811.misc (limited to 'tests') diff --git a/changelog.d/7811.misc b/changelog.d/7811.misc new file mode 100644 index 0000000000..d907bba4df --- /dev/null +++ b/changelog.d/7811.misc @@ -0,0 +1 @@ +Fix various spelling errors in comments and log lines. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cb22508f4d..40dc62ef6c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -537,7 +537,7 @@ class Auth(object): # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding # to the event's `auth_events` (e.g. joins pointing to previous joins - # when room is publically joinable). Dropping event IDs has the + # when room is publicly joinable). Dropping event IDs has the # advantage that the auth chain for the room grows slower, but we use # the auth chain in state resolution v2 to order events, which means # care must be taken if dropping events to ensure that it doesn't diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index df08bcd1bc..b1dc7ad502 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -72,7 +72,7 @@ class EmailConfig(Config): template_dir = email_config.get("template_dir") # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxilliary templates like mail.css we will need). + # we don't yet know what auxiliary templates like mail.css we will need). # (Note that loading as package_resources with jinja.PackageLoader doesn't # work for the same reason.) if not template_dir: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 07d41ec03f..a37cc9cb4a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -245,7 +245,7 @@ class FederationClient(FederationBase): event_id: event to fetch room_version: version of the room outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part + it's from an arbitrary point in the context as opposed to part of the current block of PDUs. 
Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -351,7 +351,7 @@ class FederationClient(FederationBase): outlier: bool = False, include_none: bool = False, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashs of each + """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of that PDU. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e704cf2f44..86051decd4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -717,7 +717,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warning("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignoring non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -731,7 +731,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warning("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignoring non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -741,7 +741,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warning("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignoring non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 6bbd762681..860b03f7b9 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -359,7 +359,7 @@ class BaseFederationRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. @staticmethod def from_data(data): diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 4e698981a4..12966e239b 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -119,7 +119,7 @@ class PerDestinationQueue(object): ) def send_pdu(self, pdu: EventBase, order: int) -> None: - """Add a PDU to the queue, and start the transmission loop if neccessary + """Add a PDU to the queue, and start the transmission loop if necessary Args: pdu: pdu to send @@ -129,7 +129,7 @@ class PerDestinationQueue(object): self.attempt_new_transaction() def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if neccessary. + """Add presence updates to the queue. Start the transmission loop if necessary. 
Args: states: presence to send diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9f99311419..cfdf23d366 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -746,7 +746,7 @@ class TransportLayerClient(object): def remove_user_from_group( self, destination, group_id, requester_user_id, user_id, content ): - """Remove a user fron a group + """Remove a user from a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bfb7831a02..d1bac318e7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -109,7 +109,7 @@ class Authenticator(object): self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifer = hs.get_notifier() + self.notifier = hs.get_notifier() self.replication_client = None if hs.config.worker.worker_app: @@ -175,7 +175,7 @@ class Authenticator(object): await self.store.set_destination_retry_timings(origin, None, 0, 0) # Inform the relevant places that the remote server is back up. - self.notifer.notify_remote_server_up(origin) + self.notifier.notify_remote_server_up(origin) if self.replication_client: # If we're on a worker we try and inform master about this. The # replication client doesn't hook into the notifier to avoid diff --git a/synapse/notifier.py b/synapse/notifier.py index 87c120a59c..bd41f77852 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -83,7 +83,7 @@ class _NotifierUserStream(object): self.current_token = current_token # The last token for which we should wake up any streams that have a - # token that comes before it. This gets updated everytime we get poked. + # token that comes before it. This gets updated every time we get poked. # We start it at the current token since if we get any streams # that have a token from before we have no idea whether they should be # woken up or not, so lets just wake them up. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 0843d28d4b..fb0dd04f88 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -92,11 +92,11 @@ class ReplicationEndpoint(object): # assert here that sub classes don't try and use the name. 
assert ( "instance_name" not in self.PATH_ARGS - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert ( "instance_name" not in signature(self.__class__._serialize_payload).parameters - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert self.METHOD in ("PUT", "POST", "GET") diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 523a1358d4..1b8718b11d 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -25,7 +25,7 @@ Structure of the module: * command.py - the definitions of all the valid commands * protocol.py - the TCP protocol classes * resource.py - handles streaming stream updates to replications - * streams/ - the definitons of all the valid streams + * streams/ - the definitions of all the valid streams The general interaction of the classes are: diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 0f453ff0a8..ccc7f1f0d1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -47,7 +47,7 @@ class Command(metaclass=abc.ABCMeta): @abc.abstractmethod def to_line(self) -> str: - """Serialises the comamnd for the wire. Does not include the command + """Serialises the command for the wire. Does not include the command prefix. """ diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 4198eece71..ca47f5cc88 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -317,7 +317,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index e776b63183..0a7e7f67be 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -177,7 +177,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): Args: hs outbound_redis_connection: A connection to redis that will be used to - send outbound commands (this is seperate to the redis connection + send outbound commands (this is separate to the redis connection used to subscribe). """ diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index bdddb62ad6..1c2a4cce7f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -62,7 +62,7 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overridden in sub classes. 
TypeId = None # type: str @classmethod diff --git a/synapse/streams/config.py b/synapse/streams/config.py index cd56cd91ed..ca7c16ff65 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -68,13 +68,13 @@ class PaginationConfig(object): elif from_tok: from_tok = StreamToken.from_string(from_tok) except Exception: - raise SynapseError(400, "'from' paramater is invalid") + raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) except Exception: - raise SynapseError(400, "'to' paramater is invalid") + raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index fcd2aaa9c9..5d3eddcfdc 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -68,7 +68,7 @@ class EventSources(object): The returned token does not have the current values for fields other than `room`, since they are not used during pagination. - Retuns: + Returns: Deferred[StreamToken] """ token = StreamToken( diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 60f0de70f7..c63256d3bd 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -55,7 +55,7 @@ class Clock(object): return self._reactor.seconds() def time_msec(self): - """Returns the current system time in miliseconds since epoch.""" + """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) def looping_call(self, f, msec, *args, **kwargs): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 65abf0846e..f562770922 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -352,7 +352,7 @@ class ReadWriteLock(object): # resolved when they release the lock). # # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. + # been resolved. The new reader is appended to the list of latest readers. # # Write: We know its safe to acquire the write lock when both the latest # writers and readers have been resolved. The new writer replaces the latest diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 64f35fc288..9b09c08b89 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -516,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase): """ Args: orig (function) - cached_method_name (str): The name of the chached method. + cached_method_name (str): The name of the cached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 45af8d3eeb..da20523b70 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -39,7 +39,7 @@ class Distributor(object): Signals are named simply by strings. TODO(paul): It would be nice to give signals stronger object identities, - so we can attach metadata, docstrings, detect typoes, etc... But this + so we can attach metadata, docstrings, detect typos, etc... But this model will do for today. 
""" diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 2605f3c65b..54c046b6e1 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -192,7 +192,7 @@ def _check_yield_points(f: Callable, changes: List[str]): result = yield d except Exception: # this will fish an earlier Failure out of the stack where possible, and - # thus is preferable to passing in an exeception to the Failure + # thus is preferable to passing in an exception to the Failure # constructor, since it results in less stack-mangling. result = Failure() diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index af69587196..8794317caa 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -22,7 +22,7 @@ from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) -# the intial backoff, after the first transaction fails +# the initial backoff, after the first transaction fails MIN_RETRY_INTERVAL = 10 * 60 * 1000 # how much we multiply the backoff by after each subsequent fail @@ -174,7 +174,7 @@ class RetryDestinationLimiter(object): # has been decommissioned. # If we get a 401, then we should probably back off since they # won't accept our requests for at least a while. - # 429 is us being aggresively rate limited, so lets rate limit + # 429 is us being aggressively rate limited, so lets rate limit # ourselves. if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False diff --git a/synapse/visibility.py b/synapse/visibility.py index 3dfd4af26c..0f042c5696 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -319,7 +319,7 @@ def filter_events_for_server( return True # Lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If thats the case then we don't + # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), @@ -335,7 +335,7 @@ def filter_events_for_server( visibility_ids.add(hist) # If we failed to find any history visibility events then the default - # is "shared" visiblity. + # is "shared" visibility. if not visibility_ids: all_open = True else: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 70c8e72303..f9ce609923 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") self.failureResultOf(d, SynapseError) - # should suceed on a signed object + # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") # self.assertFalse(d.called) self.get_success(d) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 95475bb651..e54ffea150 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -126,7 +126,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): events.append(self.get_success(store.get_event(valid_event_id))) - # Advance the time by anothe 2 days. After this, the first event should be + # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. 
self.reactor.advance(one_day_ms * 2 / 1000) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0fdff79aa7..3c66255dac 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def test_put_presence_disabled(self): """ - PUT to the status endpoint with use_presence disbled will NOT call + PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. """ self.hs.config.use_presence = False diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index fd641a7c2f..99c9f4e928 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -99,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): - """Tests that calling pagination API corectly the latest relations. + """Tests that calling pagination API correctly the latest relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") self.assertEquals(200, channel.code, channel.json_body) diff --git a/tests/test_mau.py b/tests/test_mau.py index 49667ed7f4..654a6fa42d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -166,7 +166,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.do_sync_for_user(token5) self.do_sync_for_user(token6) - # But old user cant + # But old user can't with self.assertRaises(SynapseError) as cm: self.do_sync_for_user(token1) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 95301c013c..58ee918f65 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable(self): - # a function which retuns an incomplete deferred, but doesn't follow + # a function which returns an incomplete deferred, but doesn't follow # the synapse rules. def blocking_function(): d = defer.Deferred() @@ -183,7 +183,7 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_await(self): - # an async function which retuns an incomplete coroutine, but doesn't + # an async function which returns an incomplete coroutine, but doesn't # follow the synapse rules. 
async def blocking_function(): -- cgit 1.5.1 From f299441cc67f31dcd47b8fdfda4a218bee9df9ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 10 Jul 2020 18:26:36 +0100 Subject: Add ability to shard the federation sender (#7798) --- changelog.d/7798.feature | 1 + docs/sample_config.yaml | 65 ++--- synapse/app/generic_worker.py | 59 ++--- synapse/config/federation.py | 129 ++++++++++ synapse/config/homeserver.py | 3 + synapse/config/server.py | 66 ----- synapse/federation/send_queue.py | 14 +- synapse/federation/sender/__init__.py | 48 +++- synapse/federation/sender/per_destination_queue.py | 22 ++ synapse/replication/tcp/commands.py | 10 +- synapse/replication/tcp/handler.py | 4 +- .../delta/58/10federation_pos_instance_name.sql | 22 ++ synapse/storage/data_stores/main/stream.py | 97 ++++++- tests/replication/test_federation_ack.py | 1 + tests/replication/test_federation_sender_shard.py | 286 +++++++++++++++++++++ 15 files changed, 670 insertions(+), 157 deletions(-) create mode 100644 changelog.d/7798.feature create mode 100644 synapse/config/federation.py create mode 100644 synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql create mode 100644 tests/replication/test_federation_sender_shard.py (limited to 'tests') diff --git a/changelog.d/7798.feature b/changelog.d/7798.feature new file mode 100644 index 0000000000..56ffaf0d4a --- /dev/null +++ b/changelog.d/7798.feature @@ -0,0 +1 @@ +Add experimental support for running multiple federation sender processes. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 164a104045..1a2d9fb153 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -118,38 +118,6 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false -# Restrict federation to the following whitelist of domains. -# N.B. we recommend also firewalling your federation listener to limit -# inbound federation traffic as early as possible, rather than relying -# purely on this application-layer restriction. If not specified, the -# default is to whitelist everything. -# -#federation_domain_whitelist: -# - lon.example.com -# - nyc.example.com -# - syd.example.com - -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. -# -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -608,6 +576,39 @@ acme: +# Restrict federation to the following whitelist of domains. +# N.B. we recommend also firewalling your federation listener to limit +# inbound federation traffic as early as possible, rather than relying +# purely on this application-layer restriction. If not specified, the +# default is to whitelist everything. +# +#federation_domain_whitelist: +# - lon.example.com +# - nyc.example.com +# - syd.example.com + +# Prevent federation requests from being sent to the following +# blacklist IP address CIDR ranges. 
If this option is not specified, or +# specified with an empty list, no ip range blacklist will be enforced. +# +# As of Synapse v1.4.0 this option also affects any outbound requests to identity +# servers provided by user input. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + + ## Caching ## # Caching can be configured through the following options. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f6792d9fc8..e90695f026 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -511,25 +511,7 @@ class GenericWorkerSlavedStore( SearchWorkerStore, BaseSlavedStore, ): - def __init__(self, database, db_conn, hs): - super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs) - - # We pull out the current federation stream position now so that we - # always have a known value for the federation position in memory so - # that we don't have to bounce via a deferred once when we start the - # replication streams. - self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) - - def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, ("federation",)) - rows = txn.fetchall() - txn.close() - - return rows[0][0] if rows else -1 + pass class GenericWorkerServer(HomeServer): @@ -812,19 +794,11 @@ class FederationSenderHandler(object): self.federation_sender = hs.get_federation_sender() self._hs = hs - # if the worker is restarted, we want to pick up where we left off in - # the replication stream, so load the position from the database. - # - # XXX is this actually worthwhile? Whenever the master is restarted, we'll - # drop some rows anyway (which is mostly fine because we're only dropping - # typing and presence notifications). If the replication stream is - # unreliable, why do we do all this hoop-jumping to store the position in the - # database? See also https://github.com/matrix-org/synapse/issues/7535. - # - self.federation_position = self.store.federation_out_pos_startup + # Stores the latest position in the federation stream we've gotten up + # to. This is always set before we use it. + self.federation_position = None self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - self._last_ack = self.federation_position def on_start(self): # There may be some events that are persisted but haven't been sent, @@ -932,7 +906,6 @@ class FederationSenderHandler(object): # We ACK this token over replication so that the master can drop # its in memory queues self._hs.get_tcp_replication().send_federation_ack(current_position) - self._last_ack = current_position except Exception: logger.exception("Error updating federation stream position") @@ -960,7 +933,7 @@ def start(config_options): ) if config.worker_app == "synapse.app.appservice": - if config.notify_appservices: + if config.appservice.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." 
@@ -970,13 +943,13 @@ def start(config_options): sys.exit(1) # Force the appservice to start since they will be disabled in the main config - config.notify_appservices = True + config.appservice.notify_appservices = True else: # For other worker types we force this to off. - config.notify_appservices = False + config.appservice.notify_appservices = False if config.worker_app == "synapse.app.pusher": - if config.start_pushers: + if config.server.start_pushers: sys.stderr.write( "\nThe pushers must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -986,13 +959,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.start_pushers = True + config.server.start_pushers = True else: # For other worker types we force this to off. - config.start_pushers = False + config.server.start_pushers = False if config.worker_app == "synapse.app.user_dir": - if config.update_user_directory: + if config.server.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -1002,13 +975,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.update_user_directory = True + config.server.update_user_directory = True else: # For other worker types we force this to off. - config.update_user_directory = False + config.server.update_user_directory = False if config.worker_app == "synapse.app.federation_sender": - if config.send_federation: + if config.federation.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -1018,10 +991,10 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.send_federation = True + config.federation.send_federation = True else: # For other worker types we force this to off. - config.send_federation = False + config.federation.send_federation = False synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/config/federation.py b/synapse/config/federation.py new file mode 100644 index 0000000000..7782ab4c9d --- /dev/null +++ b/synapse/config/federation.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hashlib import sha256 +from typing import List, Optional + +import attr +from netaddr import IPSet + +from ._base import Config, ConfigError + + +@attr.s +class ShardedFederationSendingConfig: + """Algorithm for choosing which federation sender instance is responsible + for which destionation host. + """ + + instances = attr.ib(type=List[str]) + + def should_send_to(self, instance_name: str, destination: str) -> bool: + """Whether this instance is responsible for sending transcations for + the given host. 
+ """ + + # If multiple federation senders are not defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash, modulo it by the number of federation + # senders and then checking whether this instance matches the instance + # at that index. + # + # (Technically this introduces some bias and is not entirely uniform, but + # since the hash is so large the bias is ridiculously small). + dest_hash = sha256(destination.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +class FederationConfig(Config): + section = "federation" + + def read_config(self, config, **kwargs): + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) + + federation_sender_instances = config.get("federation_sender_instances") or [] + self.federation_shard_config = ShardedFederationSendingConfig( + federation_sender_instances + ) + + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None # type: Optional[dict] + federation_domain_whitelist = config.get("federation_domain_whitelist", None) + + if federation_domain_whitelist is not None: + # turn the whitelist into a hash for speed of lookup + self.federation_domain_whitelist = {} + + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + + self.federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", [] + ) + + # Attempt to create an IPSet from the given ranges + try: + self.federation_ip_range_blacklist = IPSet( + self.federation_ip_range_blacklist + ) + + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + #federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + + # Prevent federation requests from being sent to the following + # blacklist IP address CIDR ranges. If this option is not specified, or + # specified with an empty list, no ip range blacklist will be enforced. + # + # As of Synapse v1.4.0 this option also affects any outbound requests to identity + # servers provided by user input. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) 
+ # + federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 264c274c52..8e93d31394 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -23,6 +23,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig from .key import KeyConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + FederationConfig, CacheConfig, DatabaseConfig, LoggingConfig, @@ -90,4 +92,5 @@ class HomeServerConfig(RootConfig): ThirdPartyRulesConfig, TracerConfig, RedisConfig, + FederationConfig, ] diff --git a/synapse/config/server.py b/synapse/config/server.py index 8204664883..b6afa642ca 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -23,7 +23,6 @@ from typing import Any, Dict, Iterable, List, Optional import attr import yaml -from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name @@ -136,11 +135,6 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) - # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -263,34 +257,6 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) - # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None # type: Optional[dict] - federation_domain_whitelist = config.get("federation_domain_whitelist", None) - - if federation_domain_whitelist is not None: - # turn the whitelist into a hash for speed of lookup - self.federation_domain_whitelist = {} - - for domain in federation_domain_whitelist: - self.federation_domain_whitelist[domain] = True - - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) - - # Attempt to create an IPSet from the given ranges - try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -743,38 +709,6 @@ class ServerConfig(Config): # #enable_search: false - # Restrict federation to the following whitelist of domains. - # N.B. we recommend also firewalling your federation listener to limit - # inbound federation traffic as early as possible, rather than relying - # purely on this application-layer restriction. If not specified, the - # default is to whitelist everything. 
- # - #federation_domain_whitelist: - # - lon.example.com - # - nyc.example.com - # - syd.example.com - - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. - # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. - # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) - # - federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 860b03f7b9..4fc9ff92e5 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -55,6 +55,11 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id + # We may have multiple federation sender instances, so we need to track + # their positions separately. + self._sender_instances = hs.config.federation.federation_shard_config.instances + self._sender_positions = {} + # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] @@ -261,7 +266,14 @@ class FederationRemoteSendQueue(object): def get_current_token(self): return self.pos - 1 - def federation_ack(self, token): + def federation_ack(self, instance_name, token): + if self._sender_instances: + # If we have configured multiple federation sender instances we need + # to track their positions separately, and only clear the queue up + # to the token all instances have acked. + self._sender_positions[instance_name] = token + token = min(self._sender_positions.values()) + self._clear_queue_before_pos(token) async def get_replication_rows( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 464d7a41de..4b63a0755f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -69,6 +69,9 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.federation.federation_shard_config + # map from destination to PerDestinationQueue self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] @@ -191,7 +194,13 @@ class FederationSender(object): ) return - destinations = set(destinations) + destinations = { + d + for d in destinations + if self._federation_shard_config.should_send_to( + self._instance_name, d + ) + } if send_on_behalf_of is not None: # If we are sending the event on behalf of another server @@ -322,7 +331,12 @@ class FederationSender(object): # Work out which remote servers should be poked and poke them. 
domains = yield self.state.get_current_hosts_in_room(room_id) - domains = [d for d in domains if d != self.server_name] + domains = [ + d + for d in domains + if d != self.server_name + and self._federation_shard_config.should_send_to(self._instance_name, d) + ] if not domains: return @@ -427,6 +441,10 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + continue self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") @@ -441,6 +459,12 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue + + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + continue + self._get_per_destination_queue(destination).send_presence(states) def build_and_send_edu( @@ -462,6 +486,11 @@ class FederationSender(object): logger.info("Not sending EDU to ourselves") return + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + return + edu = Edu( origin=self.server_name, destination=destination, @@ -478,6 +507,11 @@ class FederationSender(object): edu: edu to send key: clobbering key for this edu """ + if not self._federation_shard_config.should_send_to( + self._instance_name, edu.destination + ): + return + queue = self._get_per_destination_queue(edu.destination) if key: queue.send_keyed_edu(edu, key) @@ -489,6 +523,11 @@ class FederationSender(object): logger.warning("Not sending device update to ourselves") return + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() def wake_destination(self, destination: str): @@ -502,6 +541,11 @@ class FederationSender(object): logger.warning("Not waking up ourselves") return + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() @staticmethod diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 12966e239b..6402136e8a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -74,6 +74,20 @@ class PerDestinationQueue(object): self._clock = hs.get_clock() self._store = hs.get_datastore() self._transaction_manager = transaction_manager + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.federation.federation_shard_config + + self._should_send_on_this_instance = True + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + # We don't raise an exception here to avoid taking out any other + # processing. We have a guard in `attempt_new_transaction` that + # ensure we don't start sending stuff. + logger.error( + "Create a per destination queue for %s on wrong worker", destination, + ) + self._should_send_on_this_instance = False self._destination = destination self.transmission_loop_running = False @@ -180,6 +194,14 @@ class PerDestinationQueue(object): logger.debug("TX [%s] Transaction already in progress", self._destination) return + if not self._should_send_on_this_instance: + # We don't raise an exception here to avoid taking out any other + # processing. 
+                logger.error(
+                    "Trying to start a transaction to %s on wrong worker", self._destination
+                )
+                return
+
         logger.debug("TX [%s] Starting transaction loop", self._destination)

         run_as_background_process(
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ccc7f1f0d1..f33801f883 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -293,20 +293,22 @@ class FederationAckCommand(Command):

     Format::

-        FEDERATION_ACK <token>
+        FEDERATION_ACK <instance_name> <token>
     """

     NAME = "FEDERATION_ACK"

-    def __init__(self, token):
+    def __init__(self, instance_name, token):
+        self.instance_name = instance_name
         self.token = token

     @classmethod
     def from_line(cls, line):
-        return cls(int(line))
+        instance_name, token = line.split(" ")
+        return cls(instance_name, int(token))

     def to_line(self):
-        return str(self.token)
+        return "%s %s" % (self.instance_name, self.token)


 class RemovePusherCommand(Command):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 55b3b79008..80f5df60f9 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -238,7 +238,7 @@ class ReplicationCommandHandler:
         federation_ack_counter.inc()

         if self._federation_sender:
-            self._federation_sender.federation_ack(cmd.token)
+            self._federation_sender.federation_ack(cmd.instance_name, cmd.token)

     async def on_REMOVE_PUSHER(
         self, conn: AbstractConnection, cmd: RemovePusherCommand
@@ -527,7 +527,7 @@ class ReplicationCommandHandler:
         """Ack data for the federation stream. This allows the master to drop
         data stored purely in memory.
         """
-        self.send_command(FederationAckCommand(token))
+        self.send_command(FederationAckCommand(self._instance_name, token))

     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
new file mode 100644
index 0000000000..1cc2633aad
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We need to store the stream positions by instance in a sharded config world.
+--
+-- We default to master as we want the column to be NOT NULL and we correctly
+-- reset the instance name to match the config each time we start up.
+ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master'; + +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 379d758b5d..5e32c7aa1e 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -45,7 +45,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(StreamWorkerStore, self).__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._send_federation = hs.should_send_federation() + self._federation_shard_config = hs.config.federation.federation_shard_config + + # If we're a process that sends federation we may need to reset the + # `federation_stream_position` table to match the current sharding + # config. We don't do this now as otherwise two processes could conflict + # during startup which would cause one to die. + self._need_to_reset_federation_stream_positions = self._send_federation + events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, @@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events - def get_federation_out_pos(self, typ): - return self.db.simple_select_one_onecol( + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, desc="get_federation_out_pos", ) - def update_federation_out_pos(self, typ, stream_id): - return self.db.simple_update_one( + async def update_federation_out_pos(self, typ, stream_id): + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_update_one( table="federation_stream_position", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + def _reset_federation_positions_txn(self, txn): + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. 
If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. + configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = dict(txn) # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 5448d9f0dc..23be1167a3 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + return hs def test_federation_ack_sent(self): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py new file mode 100644 index 0000000000..519a2dc510 --- /dev/null +++ b/tests/replication/test_federation_sender_shard.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
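Before the sharding tests in the new file below, a minimal self-contained sketch (not part of the patch) of the bucketing rule that `ShardedFederationSendingConfig.should_send_to` implements in the config diff above: the destination host is hashed with SHA-256 and reduced modulo the number of configured sender instances, so every worker derives the same assignment without any coordination. The instance and host names here are made up for illustration.

```python
# Standalone sketch of the destination -> sender-instance bucketing shown in
# synapse/config/federation.py above. Instance/host names are illustrative.
from hashlib import sha256
from typing import List


def pick_sender(destination: str, instances: List[str]) -> str:
    """Return the federation sender instance responsible for `destination`."""
    if len(instances) <= 1:
        # With zero or one configured instances there is nothing to shard;
        # the real config simply answers True for every destination.
        return instances[0] if instances else "master"

    dest_hash = sha256(destination.encode("utf8")).digest()
    dest_int = int.from_bytes(dest_hash, byteorder="little")
    return instances[dest_int % len(instances)]


if __name__ == "__main__":
    senders = ["sender1", "sender2"]
    for host in ("matrix.org", "example.com", "other_server_3"):
        print(host, "->", pick_sender(host, senders))
```

Because the assignment depends only on the destination name and the configured instance list, changing that list reshuffles destinations across senders; that is why the schema change and `_reset_federation_positions_txn` above fall back to the minimum stored stream position rather than trusting the old per-instance rows.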
+import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.app.generic_worker import GenericWorkerServer +from synapse.events.builder import EventBuilderFactory +from synapse.replication.http import streams +from synapse.replication.tcp.handler import ReplicationCommandHandler +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.rest.admin import register_servlets_for_client_rest_resource +from synapse.rest.client.v1 import login, room +from synapse.types import UserID + +from tests import unittest +from tests.server import FakeTransport + +logger = logging.getLogger(__name__) + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + + servlets = [ + streams.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() + + store = hs.get_datastore() + self.database = store.db + + self.reactor.lookups["testserv"] = "1.2.3.4" + + def default_config(self): + conf = super().default_config() + conf["send_federation"] = False + return conf + + def make_worker_hs(self, extra_config={}): + config = self._get_worker_hs_config() + config.update(extra_config) + + mock_federation_client = Mock(spec=["put_json"]) + mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + + worker_hs = self.setup_test_homeserver( + http_client=mock_federation_client, + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.federation_sender" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. 
+ """ + self.streamer.on_notifier_poke() + self.pump() + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + "room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room + + +class FederationSenderTestCase(BaseStreamTestCase): + servlets = [ + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + def test_send_event_single_sender(self): + """Test that using a single federation sender worker correctly sends a + new event. + """ + worker_hs = self.make_worker_hs({"send_federation": True}) + mock_client = worker_hs.get_http_client() + + user = self.register_user("user", "pass") + token = self.login("user", "pass") + + room = self.create_room_with_remote_server(user, token) + + mock_client.put_json.reset_mock() + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + # Assert that the event was sent out over federation. + mock_client.put_json.assert_called() + self.assertEqual(mock_client.put_json.call_args[0][0], "other_server") + self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus")) + + def test_send_event_sharded(self): + """Test that using two federation sender workers correctly sends + new events. 
+ """ + worker1 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client1 = worker1.get_http_client() + + worker2 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client2 = worker2.get_http_client() + + user = self.register_user("user2", "pass") + token = self.login("user2", "pass") + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() + mock_client2.reset_mock() + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) + + def test_send_typing_sharded(self): + """Test that using two federation sender workers correctly sends + new typing EDUs. + """ + worker1 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client1 = worker1.get_http_client() + + worker2 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client2 = worker2.get_http_client() + + user = self.register_user("user3", "pass") + token = self.login("user3", "pass") + + typing_handler = self.hs.get_typing_handler() + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() + mock_client2.reset_mock() + + self.get_success( + typing_handler.started_typing( + target_user=UserID.from_string(user), + auth_user=UserID.from_string(user), + room_id=room, + timeout=20000, + ) + ) + + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) -- cgit 1.5.1 From 66a4af8d9627719a875c405c8c0f49b0056811b2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Jul 2020 14:30:08 -0400 Subject: Do not use canonicaljson to magically handle decoding bytes from JSON. 
(#7802) --- changelog.d/7802.misc | 1 + synapse/api/errors.py | 6 ++---- synapse/federation/federation_server.py | 6 +++--- synapse/handlers/cas_handler.py | 2 +- synapse/http/client.py | 14 +++++++------- synapse/http/servlet.py | 14 ++------------ tests/rest/client/v1/test_login.py | 2 +- 7 files changed, 17 insertions(+), 28 deletions(-) create mode 100644 changelog.d/7802.misc (limited to 'tests') diff --git a/changelog.d/7802.misc b/changelog.d/7802.misc new file mode 100644 index 0000000000..d81f8875c5 --- /dev/null +++ b/changelog.d/7802.misc @@ -0,0 +1 @@ + Switch from simplejson to the standard library json. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 5305038c21..d5d4522336 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -15,13 +15,11 @@ # limitations under the License. """Contains exceptions and error codes.""" - +import json import logging from http import HTTPStatus from typing import Dict, List -from canonicaljson import json - from twisted.web import http logger = logging.getLogger(__name__) @@ -573,7 +571,7 @@ class HttpResponseException(CodeMessageException): # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text try: - j = json.loads(self.response) + j = json.loads(self.response.decode("utf-8")) except ValueError: j = {} diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 86051decd4..2aab9c5f55 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,10 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union -from canonicaljson import json from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -526,9 +526,9 @@ class FederationServer(FederationBase): json_result = {} # type: Dict[str, Dict[str, dict]] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json.loads(json_str) } logger.info( diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index d79ffefdb5..786e608fa2 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -104,7 +104,7 @@ class CasHandler: return user, displayname def _parse_cas_response( - self, cas_response_body: str + self, cas_response_body: bytes ) -> Tuple[str, Dict[str, Optional[str]]]: """ Retrieve the user and other parameters from the CAS response. diff --git a/synapse/http/client.py b/synapse/http/client.py index 8743e9839d..505872ee90 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,13 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
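Stepping back from the hunk below: the reason this commit adds `.decode("utf-8")` before `json.loads` is that the `json` re-exported by `canonicaljson` was simplejson, which transparently accepts `bytes`, whereas the standard library parser only gained bytes support in Python 3.6 and Synapse still ran on older interpreters at this point. That rationale is inferred from the pattern of changes rather than stated in the commit, so treat the sketch below as background, not as part of the patch.

```python
# Why the explicit decode: stdlib json only accepts bytes on Python >= 3.6,
# while simplejson (previously pulled in via canonicaljson) accepts them
# everywhere. Response bodies from Twisted's readBody() arrive as bytes.
import json

body = b'{"user_id": "@user:example.com"}'

# json.loads(body) works on Python >= 3.6 but raises TypeError on 3.5.
# The portable form used throughout this commit:
print(json.loads(body.decode("utf-8")))
```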
- +import json import logging import urllib from io import BytesIO import treq -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -371,7 +371,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json.loads(body.decode("utf-8")) else: raise HttpResponseException(response.code, response.phrase, body) @@ -412,7 +412,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json.loads(body.decode("utf-8")) else: raise HttpResponseException(response.code, response.phrase, body) @@ -441,7 +441,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) body = yield self.get_raw(uri, args, headers=headers) - return json.loads(body) + return json.loads(body.decode("utf-8")) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}, headers=None): @@ -485,7 +485,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json.loads(body.decode("utf-8")) else: raise HttpResponseException(response.code, response.phrase, body) @@ -503,7 +503,7 @@ class SimpleHttpClient(object): header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the - HTTP body at text. + HTTP body as bytes. Raises: HttpResponseException on a non-2xx HTTP response. """ diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 13fcb408a6..3cabe9d02e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,11 +14,9 @@ # limitations under the License. """ This module contains base REST classes for constructing REST servlets. """ - +import json import logging -from canonicaljson import json - from synapse.api.errors import Codes, SynapseError logger = logging.getLogger(__name__) @@ -214,16 +212,8 @@ def parse_json_value_from_request(request, allow_empty_body=False): if not content_bytes and allow_empty_body: return None - # Decode to Unicode so that simplejson will return Unicode strings on - # Python 2 - try: - content_unicode = content_bytes.decode("utf8") - except UnicodeDecodeError: - logger.warning("Unable to decode UTF-8") - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - try: - content = json.loads(content_unicode) + content = json.loads(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index fd97999956..2be7238b00 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -398,7 +398,7 @@ class CASTestCase(unittest.HomeserverTestCase): """ % cas_user_id - ) + ).encode("utf-8") mocked_http_client = Mock(spec=["get_raw"]) mocked_http_client.get_raw.side_effect = get_raw -- cgit 1.5.1 From 457096e6dfd2b5837f289366dd99e6d2f276d924 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 13 Jul 2020 13:31:46 -0400 Subject: Support handling registration requests across multiple client readers. 
(#7830) --- changelog.d/7830.feature | 1 + synapse/handlers/deactivate_account.py | 3 +- tests/replication/test_client_reader_shard.py | 133 ++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7830.feature create mode 100644 tests/replication/test_client_reader_shard.py (limited to 'tests') diff --git a/changelog.d/7830.feature b/changelog.d/7830.feature new file mode 100644 index 0000000000..b4f614084d --- /dev/null +++ b/changelog.d/7830.feature @@ -0,0 +1 @@ +Add support for handling registration requests across multiple client reader workers. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 2afb390a92..3e3e6bd475 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -40,7 +40,8 @@ class DeactivateAccountHandler(BaseHandler): # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). - hs.get_reactor().callWhenRunning(self._start_user_parting) + if hs.config.worker_app is None: + hs.get_reactor().callWhenRunning(self._start_user_parting) self._account_validity_enabled = hs.config.account_validity.enabled diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py new file mode 100644 index 0000000000..b7d753e0a3 --- /dev/null +++ b/tests/replication/test_client_reader_shard.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from synapse.api.constants import LoginType +from synapse.app.generic_worker import GenericWorkerServer +from synapse.http.server import JsonResource +from synapse.http.site import SynapseRequest +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.rest.client.v2_alpha import register + +from tests import unittest +from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker +from tests.server import FakeChannel, render + +logger = logging.getLogger(__name__) + + +class ClientReaderTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + + servlets = [ + register.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() + + store = hs.get_datastore() + self.database = store.db + + self.recaptcha_checker = DummyRecaptchaChecker(hs) + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker + + self.reactor.lookups["testserv"] = "1.2.3.4" + + def make_worker_hs(self, extra_config={}): + config = self._get_worker_hs_config() + config.update(extra_config) + + worker_hs = self.setup_test_homeserver( + homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor, + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource. + resource = JsonResource(self.hs) + + for servlet in self.servlets: + servlet(worker_hs, resource) + + # Essentially HomeserverTestCase.render. + def _render(request): + render(request, self.resource, self.reactor) + + return worker_hs, _render + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def test_register_single_worker(self): + """Test that registration works when using a single client reader worker. + """ + _, worker_render = self.make_worker_hs() + + request_1, channel_1 = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + worker_render(request_1) + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + request_2, channel_2 = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) # type: SynapseRequest, FakeChannel + worker_render(request_2) + self.assertEqual(request_2.code, 200) + + # We're given a registered user. + self.assertEqual(channel_2.json_body["user_id"], "@user:test") + + def test_register_multi_worker(self): + """Test that registration works when using multiple client reader workers. 
+ """ + _, worker_render_1 = self.make_worker_hs() + _, worker_render_2 = self.make_worker_hs() + + request_1, channel_1 = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + worker_render_1(request_1) + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + request_2, channel_2 = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) # type: SynapseRequest, FakeChannel + worker_render_2(request_2) + self.assertEqual(request_2.code, 200) + + # We're given a registered user. + self.assertEqual(channel_2.json_body["user_id"], "@user:test") -- cgit 1.5.1 From 77d2c054100f4b0ebe8a027d510a42ff5af09667 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Jul 2020 07:16:43 -0400 Subject: Add the option to validate the `iss` and `aud` claims for JWT logins. (#7827) --- changelog.d/7827.feature | 1 + docs/jwt.md | 16 ++++-- docs/sample_config.yaml | 21 ++++++++ synapse/config/jwt_config.py | 28 ++++++++++ synapse/rest/client/v1/login.py | 25 ++++++--- tests/rest/client/v1/test_login.py | 106 ++++++++++++++++++++++++++++++++++--- 6 files changed, 182 insertions(+), 15 deletions(-) create mode 100644 changelog.d/7827.feature (limited to 'tests') diff --git a/changelog.d/7827.feature b/changelog.d/7827.feature new file mode 100644 index 0000000000..0fd116e198 --- /dev/null +++ b/changelog.d/7827.feature @@ -0,0 +1 @@ +Add the option to validate the `iss` and `aud` claims for JWT logins. diff --git a/docs/jwt.md b/docs/jwt.md index 289d66b365..93b8d05236 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -20,8 +20,17 @@ follows: Note that the login type of `m.login.jwt` is supported, but is deprecated. This will be removed in a future version of Synapse. -The `jwt` should encode the local part of the user ID as the standard `sub` -claim. In the case that the token is not valid, the homeserver must respond with +The `token` field should include the JSON web token with the following claims: + +* The `sub` (subject) claim is required and should encode the local part of the + user ID. +* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`) + claims are optional, but validated if present. +* The issuer (`iss`) claim is optional, but required and validated if configured. +* The audience (`aud`) claim is optional, but required and validated if configured. + Providing the audience claim when not configured will cause validation to fail. + +In the case that the token is not valid, the homeserver must respond with `401 Unauthorized` and an error code of `M_UNAUTHORIZED`. (Note that this differs from the token based logins which return a @@ -55,7 +64,8 @@ sample settings. Although JSON Web Tokens are typically generated from an external server, the examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly. -1. Configure Synapse with JWT logins: +1. Configure Synapse with JWT logins, note that this example uses a pre-shared + secret and an algorithm of HS256: ```yaml jwt_config: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a2d9fb153..9d94495464 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1812,6 +1812,9 @@ sso: # Each JSON Web Token needs to contain a "sub" (subject) claim, which is # used as the localpart of the mxid. 
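As a concrete illustration of the claim handling described in `docs/jwt.md` above, a compatible token could be minted and verified with PyJWT roughly as follows. This snippet is not part of the patch: the secret, issuer, audience and username are placeholders, and it simply mirrors the validation options that `LoginRestServlet` passes to `jwt.decode` further down.

```python
# Hypothetical end-to-end example of the claims checked by the new
# `issuer`/`audiences` options; all values here are placeholders.
import time

import jwt  # PyJWT

secret = "provided-by-your-issuer"

token = jwt.encode(
    {
        "sub": "lovelace",               # becomes the localpart of the mxid
        "iss": "test-issuer",            # required iff `issuer` is configured
        "aud": "test-audience",          # required iff `audiences` is configured
        "exp": int(time.time()) + 3600,  # optional, validated if present
    },
    secret,
    algorithm="HS256",
)

# Server-side check, mirroring the options used in LoginRestServlet:
claims = jwt.decode(
    token,
    secret,
    algorithms=["HS256"],
    issuer="test-issuer",
    audience="test-audience",
)
print(claims["sub"])  # -> "lovelace"
```

The encoded value is what a client submits as the `token` field of an `org.matrix.login.jwt` (or deprecated `m.login.jwt`) login request, which is what the updated `JWTTestCase` tests exercise.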
# +# Additionally, the expiration time ("exp"), not before time ("nbf"), +# and issued at ("iat") claims are validated if present. +# # Note that this is a non-standard login type and client support is # expected to be non-existant. # @@ -1839,6 +1842,24 @@ sso: # #algorithm: "provided-by-your-issuer" + # The issuer to validate the "iss" claim against. + # + # Optional, if provided the "iss" claim will be required and + # validated for all JSON web tokens. + # + #issuer: "provided-by-your-issuer" + + # A list of audiences to validate the "aud" claim against. + # + # Optional, if provided the "aud" claim will be required and + # validated for all JSON web tokens. + # + # Note that if the "aud" claim is included in a JSON web token then + # validation will fail without configuring audiences. + # + #audiences: + # - "provided-by-your-issuer" + password_config: # Uncomment to disable password login diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index fce96b4acf..3252ad9e7f 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -32,6 +32,11 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + # The issuer and audiences are optional, if provided, it is asserted + # that the claims exist on the JWT. + self.jwt_issuer = jwt_config.get("issuer") + self.jwt_audiences = jwt_config.get("audiences") + try: import jwt @@ -42,6 +47,8 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_issuer = None + self.jwt_audiences = None def generate_config_section(self, **kwargs): return """\ @@ -52,6 +59,9 @@ class JWTConfig(Config): # Each JSON Web Token needs to contain a "sub" (subject) claim, which is # used as the localpart of the mxid. # + # Additionally, the expiration time ("exp"), not before time ("nbf"), + # and issued at ("iat") claims are validated if present. + # # Note that this is a non-standard login type and client support is # expected to be non-existant. # @@ -78,4 +88,22 @@ class JWTConfig(Config): # Required if 'enabled' is true. # #algorithm: "provided-by-your-issuer" + + # The issuer to validate the "iss" claim against. + # + # Optional, if provided the "iss" claim will be required and + # validated for all JSON web tokens. + # + #issuer: "provided-by-your-issuer" + + # A list of audiences to validate the "aud" claim against. + # + # Optional, if provided the "aud" claim will be required and + # validated for all JSON web tokens. + # + # Note that if the "aud" claim is included in a JSON web token then + # validation will fail without configuring audiences. + # + #audiences: + # - "provided-by-your-issuer" """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 64d5c58b65..326ffa0056 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__() self.hs = hs + + # JWT configuration variables. self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.jwt_issuer = hs.config.jwt_issuer + self.jwt_audiences = hs.config.jwt_audiences + + # SSO configuration. 
self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet): ) import jwt - from jwt.exceptions import InvalidTokenError try: payload = jwt.decode( - token, self.jwt_secret, algorithms=[self.jwt_algorithm] + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm], + issuer=self.jwt_issuer, + audience=self.jwt_audiences, + ) + except jwt.PyJWTError as e: + # A JWT error occurred, return some info back to the client. + raise LoginError( + 401, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.UNAUTHORIZED, ) - except jwt.ExpiredSignatureError: - raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) - except InvalidTokenError: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user = payload.get("sub", None) if user is None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2be7238b00..4413bb3932 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase): ] jwt_secret = "secret" + jwt_algorithm = "HS256" def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() self.hs.config.jwt_enabled = True self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = "HS256" + self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, "HS256").decode("ascii") + return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") def jwt_login(self, *args): params = json.dumps( @@ -548,20 +549,28 @@ class JWTTestCase(unittest.HomeserverTestCase): channel = self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "Invalid JWT") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) def test_login_jwt_expired(self): channel = self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "JWT expired") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Signature has expired" + ) def test_login_jwt_not_before(self): now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "Invalid JWT") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: The token is not yet valid (nbf)", + ) def test_login_no_sub(self): channel = self.jwt_login({"username": "root"}) @@ -569,6 +578,88 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["error"], "Invalid JWT") + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "issuer": "test-issuer", + } + } + ) + def 
test_login_iss(self): + """Test validating the issuer claim.""" + # A valid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid issuer" + ) + + # Not providing an issuer. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "iss" claim', + ) + + def test_login_iss_no_config(self): + """Test providing an issuer claim without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "audiences": ["test-audience"], + } + } + ) + def test_login_aud(self): + """Test validating the audience claim.""" + # A valid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + # Not providing an audience. 
+ channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "aud" claim', + ) + + def test_login_aud_no_config(self): + """Test providing an audience without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + def test_login_no_token(self): params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) @@ -658,4 +749,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "Invalid JWT") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) -- cgit 1.5.1 From 491f0dab1ba5456f52b0710461fbaabc594ff1f5 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 14 Jul 2020 13:36:23 +0200 Subject: Add delete room admin endpoint (#7613) The Delete Room admin API allows server admins to remove rooms from server and block these rooms. `DELETE /_synapse/admin/v1/rooms/` It is a combination and improvement of "[Shutdown room](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/shutdown_room.md)" and "[Purge room](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/purge_room.md)" API. Fixes: #6425 It also fixes a bug in [synapse/storage/data_stores/main/room.py](synapse/storage/data_stores/main/room.py) in ` get_room_with_stats`. It should return `None` if the room is unknown. But it returns an `IndexError`. https://github.com/matrix-org/synapse/blob/901b1fa561e3cc661d78aa96d59802cf2078cb0d/synapse/storage/data_stores/main/room.py#L99-L105 Related to: - #5575 - https://github.com/Awesome-Technologies/synapse-admin/issues/17 Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/7613.feature | 1 + docs/admin_api/purge_room.md | 2 + docs/admin_api/rooms.md | 94 ++++++++ docs/admin_api/shutdown_room.md | 2 + synapse/handlers/room.py | 208 +++++++++++++++- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 157 ++++-------- synapse/server.py | 10 +- synapse/server.pyi | 2 + synapse/storage/data_stores/main/room.py | 7 +- tests/rest/admin/test_room.py | 395 +++++++++++++++++++++++++++++++ tests/storage/test_room.py | 8 + 12 files changed, 775 insertions(+), 113 deletions(-) create mode 100644 changelog.d/7613.feature (limited to 'tests') diff --git a/changelog.d/7613.feature b/changelog.d/7613.feature new file mode 100644 index 0000000000..b671dc2fcc --- /dev/null +++ b/changelog.d/7613.feature @@ -0,0 +1 @@ +Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms//delete`). Contributed by @dklimpel. diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md index 64ea7b6a64..ae01a543c6 100644 --- a/docs/admin_api/purge_room.md +++ b/docs/admin_api/purge_room.md @@ -5,6 +5,8 @@ This API will remove all trace of a room from your database. All local users must have left the room before it can be removed. 
+See also: [Delete Room API](rooms.md#delete-room-api) + The API is: ``` diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 624e7745ba..3f26adc16c 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -318,3 +318,97 @@ Response: "state_events": 93534 } ``` + +# Delete Room API + +The Delete Room admin API allows server admins to remove rooms from server +and block these rooms. +It is a combination and improvement of "[Shutdown room](shutdown_room.md)" +and "[Purge room](purge_room.md)" API. + +Shuts down a room. Moves all local users and room aliases automatically to a +new room if `new_room_user_id` is set. Otherwise local users only +leave the room without any information. + +The new room will be created with the user specified by the `new_room_user_id` parameter +as room administrator and will contain a message explaining what happened. Users invited +to the new room will have power level `-10` by default, and thus be unable to speak. + +If `block` is `True` it prevents new joins to the old room. + +This API will remove all trace of the old room from your database after removing +all local users. +Depending on the amount of history being purged a call to the API may take +several minutes or longer. + +The local server will only have the power to move local user and room aliases to +the new room. Users on other servers will be unaffected. + +The API is: + +```json +POST /_synapse/admin/v1/rooms//delete +``` + +with a body of: +```json +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", + "block": true +} +``` + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see [README.rst](README.rst). + +A response body like the following is returned: + +```json +{ + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" +} +``` + +## Parameters + +The following parameters should be set in the URL: + +* `room_id` - The ID of the room. + +The following JSON body parameters are available: + +* `new_room_user_id` - Optional. If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be moved into that + room. If not set, no new room will be created and the users will just be removed + from the old room. The user ID must be on the local server, but does not necessarily + have to belong to a registered user. +* `room_name` - Optional. A string representing the name of the room that new users will be + invited to. Defaults to `Content Violation Notification` +* `message` - Optional. A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly convey why the + original room was shut down. Defaults to `Sharing illegal content on this server + is not permitted and rooms in violation will be blocked.` +* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to + join the room. Defaults to `false`. + +The JSON body must not be empty. The body must be at least `{}`. + +## Response + +The following fields are returned in the JSON response body: + +* `kicked_users` - An array of users (`user_id`) that were kicked. 
+* `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. +* `local_aliases` - An array of strings representing the local aliases that were migrated from + the old room to the new. +* `new_room_id` - A string representing the room ID of the new room. diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md index 54ce1cd234..808caeec79 100644 --- a/docs/admin_api/shutdown_room.md +++ b/docs/admin_api/shutdown_room.md @@ -10,6 +10,8 @@ disallow any further invites or joins. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. +See also: [Delete Room API](rooms.md#delete-room-api) + ## API You will need to authenticate with an access token for an admin user. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 950a84acd0..fb37d371ad 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,11 +22,12 @@ import logging import math import string from collections import OrderedDict -from typing import Tuple +from typing import Optional, Tuple from synapse.api.constants import ( EventTypes, JoinRules, + Membership, RoomCreationPreset, RoomEncryptionAlgorithms, ) @@ -43,9 +44,10 @@ from synapse.types import ( StateMap, StreamToken, UserID, + create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, maybe_awaitable from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -1089,3 +1091,205 @@ class RoomEventSource(object): def get_current_key_for_room(self, room_id): return self.store.get_room_events_max_id(room_id) + + +class RoomShutdownHandler(object): + + DEFAULT_MESSAGE = ( + "Sharing illegal content on this server is not permitted and rooms in" + " violation will be blocked." + ) + DEFAULT_ROOM_NAME = "Content Violation Notification" + + def __init__(self, hs): + self.hs = hs + self.room_member_handler = hs.get_room_member_handler() + self._room_creation_handler = hs.get_room_creation_handler() + self._replication = hs.get_replication_data_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state = hs.get_state_handler() + self.store = hs.get_datastore() + + async def shutdown_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + ) -> dict: + """ + Shuts down a room. Moves all local users and room aliases automatically + to a new room if `new_room_user_id` is set. Otherwise local users only + leave the room without any information. + + The new room will be created with the user specified by the + `new_room_user_id` parameter as room administrator and will contain a + message explaining what happened. Users invited to the new room will + have power level `-10` by default, and thus be unable to speak. + + The local server will only have the power to move local user and room + aliases to the new room. Users on other servers will be unaffected. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. 
If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + + Returns: a dict containing the following keys: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. + new_room_id: A string representing the room ID of the new room. + """ + + if not new_room_name: + new_room_name = self.DEFAULT_ROOM_NAME + if not message: + message = self.DEFAULT_MESSAGE + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self.store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) + + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. + if block: + await self.store.block_room(room_id, requester_user_id) + + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + room_creator_requester = create_requester(new_room_user_id) + + info, stream_id = await self._room_creation_handler.create_room( + room_creator_requester, + config={ + "preset": RoomCreationPreset.PUBLIC_CHAT, + "name": new_room_name, + "power_level_content_override": {"users_default": -10}, + }, + ratelimit=False, + ) + new_room_id = info["room_id"] + + logger.info( + "Shutting down room %r, joining to new room: %r", room_id, new_room_id + ) + + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + else: + new_room_id = None + logger.info("Shutting down room %r", room_id) + + users = await self.state.get_current_users_in_room(room_id) + kicked_users = [] + failed_to_kick_users = [] + for user_id in users: + if not self.hs.is_mine_id(user_id): + continue + + logger.info("Kicking %r from %r...", user_id, room_id) + + try: + # Kick users from room + target_requester = create_requester(user_id) + _, stream_id = await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=room_id, + action=Membership.LEAVE, + content={}, + ratelimit=False, + require_consent=False, + ) + + # Wait for leave to come in over replication before trying to forget. 
+ await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + + await self.room_member_handler.forget(target_requester.user, room_id) + + # Join users to new room + if new_room_user_id: + await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=new_room_id, + action=Membership.JOIN, + content={}, + ratelimit=False, + require_consent=False, + ) + + kicked_users.append(user_id) + except Exception: + logger.exception( + "Failed to leave old room and join new room for %r", user_id + ) + failed_to_kick_users.append(user_id) + + # Send message in new room and move aliases + if new_room_user_id: + await self.event_creation_handler.create_and_send_nonmember_event( + room_creator_requester, + { + "type": "m.room.message", + "content": {"body": message, "msgtype": "m.text"}, + "room_id": new_room_id, + "sender": new_room_user_id, + }, + ratelimit=False, + ) + + aliases_for_room = await maybe_awaitable( + self.store.get_aliases_for_room(room_id) + ) + + await self.store.update_aliases_for_room( + room_id, new_room_id, requester_user_id + ) + else: + aliases_for_room = [] + + return { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + } diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 9eda592de9..dc373bc5a3 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -35,6 +35,7 @@ from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( + DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, RoomRestServlet, @@ -200,6 +201,7 @@ def register_servlets(hs, http_server): register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index e07c32118d..544be47060 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import List, Optional -from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -32,7 +33,6 @@ from synapse.rest.admin._base import ( ) from synapse.storage.data_stores.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -46,20 +46,10 @@ class ShutdownRoomRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P[^/]+)") - DEFAULT_MESSAGE = ( - "Sharing illegal content on this server is not permitted and rooms in" - " violation will be blocked." 
- ) - def __init__(self, hs): self.hs = hs - self.store = hs.get_datastore() - self.state = hs.get_state_handler() - self._room_creation_handler = hs.get_room_creation_handler() - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - self._replication = hs.get_replication_data_handler() + self.room_shutdown_handler = hs.get_room_shutdown_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -67,116 +57,65 @@ class ShutdownRoomRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ["new_room_user_id"]) - new_room_user_id = content["new_room_user_id"] - - room_creator_requester = create_requester(new_room_user_id) - - message = content.get("message", self.DEFAULT_MESSAGE) - room_name = content.get("room_name", "Content Violation Notification") - info, stream_id = await self._room_creation_handler.create_room( - room_creator_requester, - config={ - "preset": RoomCreationPreset.PUBLIC_CHAT, - "name": room_name, - "power_level_content_override": {"users_default": -10}, - }, - ratelimit=False, + ret = await self.room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content["new_room_user_id"], + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=True, ) - new_room_id = info["room_id"] - requester_user_id = requester.user.to_string() + return (200, ret) - logger.info( - "Shutting down room %r, joining to new room: %r", room_id, new_room_id - ) - - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. - await self.store.block_room(room_id, requester_user_id) - - # We now wait for the create room to come back in via replication so - # that we can assume that all the joins/invites have propogated before - # we try and auto join below. - # - # TODO: Currently the events stream is written to from master - await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id - ) - users = await self.state.get_current_users_in_room(room_id) - kicked_users = [] - failed_to_kick_users = [] - for user_id in users: - if not self.hs.is_mine_id(user_id): - continue +class DeleteRoomRestServlet(RestServlet): + """Delete a room from server. It is a combination and improvement of + shut down and purge room. + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto + joined to the new room. + It will remove all trace of a room from the database. + """ - logger.info("Kicking %r from %r...", user_id, room_id) + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete$") - try: - target_requester = create_requester(user_id) - _, stream_id = await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=room_id, - action=Membership.LEAVE, - content={}, - ratelimit=False, - require_consent=False, - ) - - # Wait for leave to come in over replication before trying to forget. 
- await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id - ) + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() - await self.room_member_handler.forget(target_requester.user, room_id) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) - await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=new_room_id, - action=Membership.JOIN, - content={}, - ratelimit=False, - require_consent=False, - ) + content = parse_json_object_from_request(request) - kicked_users.append(user_id) - except Exception: - logger.exception( - "Failed to leave old room and join new room for %r", user_id - ) - failed_to_kick_users.append(user_id) - - await self.event_creation_handler.create_and_send_nonmember_event( - room_creator_requester, - { - "type": "m.room.message", - "content": {"body": message, "msgtype": "m.text"}, - "room_id": new_room_id, - "sender": new_room_user_id, - }, - ratelimit=False, - ) + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) + ret = await self.room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, ) - await self.store.update_aliases_for_room( - room_id, new_room_id, requester_user_id - ) + # Purge room + await self.pagination_handler.purge_room(room_id) - return ( - 200, - { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - }, - ) + return (200, ret) class ListRoomRestServlet(RestServlet): diff --git a/synapse/server.py b/synapse/server.py index 6acce2e23f..d5ebaea7f7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -73,7 +73,11 @@ from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler -from synapse.handlers.room import RoomContextHandler, RoomCreationHandler +from synapse.handlers.room import ( + RoomContextHandler, + RoomCreationHandler, + RoomShutdownHandler, +) from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler @@ -144,6 +148,7 @@ class HomeServer(object): "handlers", "auth", "room_creation_handler", + "room_shutdown_handler", "state_handler", "state_resolution_handler", "presence_handler", @@ -357,6 +362,9 @@ class HomeServer(object): def build_room_creation_handler(self): return RoomCreationHandler(self) + def build_room_shutdown_handler(self): + return RoomShutdownHandler(self) + def build_sendmail(self): return sendmail diff --git a/synapse/server.pyi b/synapse/server.pyi index fe8024d2d4..58cd099e6d 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -71,6 +71,8 @@ class 
HomeServer(object): pass def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass + def get_room_shutdown_handler(self) -> synapse.handlers.room.RoomShutdownHandler: + pass def get_event_creation_handler( self, ) -> synapse.handlers.message.EventCreationHandler: diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index c473cf158f..dace20e6db 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -118,7 +118,12 @@ class RoomWorkerStore(SQLBaseStore): WHERE room_id = ? """ txn.execute(sql, [room_id]) - res = self.db.cursor_to_dict(txn)[0] + # Catch error if sql returns empty result to return "None" instead of an error + try: + res = self.db.cursor_to_dict(txn)[0] + except IndexError: + return None + res["federatable"] = bool(res["federatable"]) res["public"] = bool(res["public"]) return res diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ae6d05a043..a80537c4fc 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -151,6 +151,401 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): ) +class DeleteRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "POST", self.url, json.dumps({}), access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + url = "/_synapse/admin/v1/rooms/!unknown:test/delete" + + request, channel = self.make_request( + "POST", url, json.dumps({}), access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. 
+ """ + url = "/_synapse/admin/v1/rooms/invalidroom/delete" + + request, channel = self.make_request( + "POST", url, json.dumps({}), access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom is not a legal room ID", channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + body = json.dumps({"new_room_user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("kicked_users", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + body = json.dumps({"new_room_user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "User must be our own: @not:exist.bla", channel.json_body["error"], + ) + + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + body = json.dumps({"block": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": True}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. 
+ """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. 
+ """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id, expect=True): + """Assert that the room is blocked or not + """ + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id): + """Assert there is now no longer anyone in the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id, user_id): + """Test that user is member of the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id): + """Test that the following tables have been purged of all rows related to the room. + """ + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. 
+ """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + class PurgeRoomTestCase(unittest.HomeserverTestCase): """Test /purge_room admin API. """ diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3b78d48896..b1dceb2918 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -55,6 +55,10 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_unknown_room(self): + self.assertIsNone((yield self.store.get_room("!uknown:test")),) + @defer.inlineCallbacks def test_get_room_with_stats(self): self.assertDictContainsSubset( @@ -66,6 +70,10 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room_with_stats(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats_unknown_room(self): + self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks -- cgit 1.5.1 From 111e70d75c2e1e82f844e4a18a34ae579166dd9a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Jul 2020 07:10:21 -0400 Subject: Return the proper 403 Forbidden error during errors with JWT logins. (#7844) --- changelog.d/7844.bugfix | 1 + docs/jwt.md | 5 +---- synapse/rest/client/v1/login.py | 8 +++---- tests/rest/client/v1/test_login.py | 43 +++++++++++++++++++------------------- 4 files changed, 27 insertions(+), 30 deletions(-) create mode 100644 changelog.d/7844.bugfix (limited to 'tests') diff --git a/changelog.d/7844.bugfix b/changelog.d/7844.bugfix new file mode 100644 index 0000000000..ad296f1b3c --- /dev/null +++ b/changelog.d/7844.bugfix @@ -0,0 +1 @@ +Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. diff --git a/docs/jwt.md b/docs/jwt.md index 93b8d05236..5be9fd26e3 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims: Providing the audience claim when not configured will cause validation to fail. In the case that the token is not valid, the homeserver must respond with -`401 Unauthorized` and an error code of `M_UNAUTHORIZED`. - -(Note that this differs from the token based logins which return a -`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.) +`403 Forbidden` and an error code of `M_FORBIDDEN`. As with other login types, there are additional fields (e.g. `device_id` and `initial_device_display_name`) which can be included in the above request. 
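[Editorial aside, not part of the patch: the `docs/jwt.md` change above describes the `org.matrix.login.jwt` flow whose error codes this commit adjusts. A minimal client-side sketch of that flow is shown below; the homeserver URL and shared secret are placeholders, and an HS256 shared secret matching the homeserver's `jwt_config` is assumed. The `iss`/`aud` claims are only validated if the homeserver configures an issuer or audience list, as the tests below exercise.]

```python
# Minimal sketch (placeholder homeserver URL and shared secret) of logging in
# with the org.matrix.login.jwt login type described in docs/jwt.md.
import jwt  # PyJWT
import requests

HOMESERVER = "https://homeserver.example"  # placeholder
SECRET = "jwt-shared-secret"               # placeholder; must match the homeserver config

token = jwt.encode(
    {"sub": "kermit", "iss": "test-issuer", "aud": "test-audience"},
    SECRET,
    algorithm="HS256",
)
if isinstance(token, bytes):  # PyJWT 1.x returns bytes; 2.x returns str
    token = token.decode("ascii")

resp = requests.post(
    HOMESERVER + "/_matrix/client/r0/login",
    json={"type": "org.matrix.login.jwt", "token": token},
)
resp.raise_for_status()
body = resp.json()
print(body["user_id"], body["access_token"])
```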
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 326ffa0056..379f668d6f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN ) import jwt @@ -387,14 +387,12 @@ class LoginRestServlet(RestServlet): except jwt.PyJWTError as e: # A JWT error occurred, return some info back to the client. raise LoginError( - 401, - "JWT validation failed: %s" % (str(e),), - errcode=Codes.UNAUTHORIZED, + 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, ) user = payload.get("sub", None) if user is None: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 4413bb3932..db52725cfe 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_invalid_signature(self): channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Signature verification failed", @@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_expired(self): channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Signature has expired" ) @@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_not_before(self): now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: The token is not yet valid (nbf)", @@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_no_sub(self): channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config( @@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase): # An invalid issuer. 
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Invalid issuer" ) # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], 'JWT validation failed: Token is missing the "iss" claim', @@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase): # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Invalid audience" ) # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], 'JWT validation failed: Token is missing the "aud" claim', @@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_aud_no_config(self): """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Invalid audience" ) @@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase): params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def test_login_jwt_invalid_signature(self): channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Signature verification failed", -- cgit 1.5.1 From b11450dedc59b117ad23426b47f2465c459ea62a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 
15 Jul 2020 08:48:58 -0400 Subject: Convert E2E key and room key handlers to async/await. (#7851) --- changelog.d/7851.misc | 1 + synapse/handlers/e2e_keys.py | 147 ++++++-------- synapse/handlers/e2e_room_keys.py | 75 ++++--- tests/handlers/test_e2e_keys.py | 286 ++++++++++++++++----------- tests/handlers/test_e2e_room_keys.py | 373 +++++++++++++++++++++++------------ 5 files changed, 521 insertions(+), 361 deletions(-) create mode 100644 changelog.d/7851.misc (limited to 'tests') diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc new file mode 100644 index 0000000000..e5cf540edf --- /dev/null +++ b/changelog.d/7851.misc @@ -0,0 +1 @@ +Convert E2E keys and room keys handlers to async/await. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a7e60cbc26..361dd64cd2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -77,8 +77,7 @@ class E2eKeysHandler(object): ) @trace - @defer.inlineCallbacks - def query_devices(self, query_body, timeout, from_user_id): + async def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -124,7 +123,7 @@ class E2eKeysHandler(object): failures = {} results = {} if local_query: - local_result = yield self.query_local_devices(local_query) + local_result = await self.query_local_devices(local_query) for user_id, keys in local_result.items(): if user_id in local_query: results[user_id] = keys @@ -142,7 +141,7 @@ class E2eKeysHandler(object): ( user_ids_not_in_cache, remote_results, - ) = yield self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache(query_list) for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): @@ -161,14 +160,13 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) # Now fetch any devices that we don't have in our cache @trace - @defer.inlineCallbacks - def do_remote_query(destination): + async def do_remote_query(destination): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -192,7 +190,7 @@ class E2eKeysHandler(object): if device_list: continue - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: continue @@ -201,11 +199,11 @@ class E2eKeysHandler(object): # done an initial sync on the device list so we do it now. 
try: if self._is_master: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_devices = await self.device_handler.device_list_updater.user_device_resync( user_id ) else: - user_devices = yield self._user_device_resync_client( + user_devices = await self._user_device_resync_client( user_id=user_id ) @@ -227,7 +225,7 @@ class E2eKeysHandler(object): destination_query.pop(user_id) try: - remote_result = yield self.federation.query_client_keys( + remote_result = await self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout ) @@ -251,7 +249,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(do_remote_query, destination) @@ -267,8 +265,7 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks - def get_cross_signing_keys_from_cache(self, query, from_user_id): + async def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database Args: @@ -289,7 +286,7 @@ class E2eKeysHandler(object): user_ids = list(query) - keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) for user_id, user_info in keys.items(): if user_info is None: @@ -315,8 +312,7 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def query_local_devices(self, query): + async def query_local_devices(self, query): """Get E2E device keys for local users Args: @@ -354,7 +350,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = yield self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -364,16 +360,15 @@ class E2eKeysHandler(object): log_kv(results) return result_dict - @defer.inlineCallbacks - def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys(self, query_body): """ Handle a device key query from a federated server """ device_keys_query = query_body.get("device_keys", {}) - res = yield self.query_local_devices(device_keys_query) + res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) @@ -382,8 +377,7 @@ class E2eKeysHandler(object): return ret @trace - @defer.inlineCallbacks - def claim_one_time_keys(self, query, timeout): + async def claim_one_time_keys(self, query, timeout): local_query = [] remote_queries = {} @@ -399,7 +393,7 @@ class E2eKeysHandler(object): set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) - results = yield self.store.claim_e2e_one_time_keys(local_query) + results = await self.store.claim_e2e_one_time_keys(local_query) json_result = {} failures = {} @@ -411,12 +405,11 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def claim_client_keys(destination): + async def claim_client_keys(destination): set_tag("destination", destination) device_keys = remote_queries[destination] try: - remote_result = yield self.federation.claim_client_keys( + remote_result = await 
self.federation.claim_client_keys( destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): @@ -429,7 +422,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(claim_client_keys, destination) @@ -454,9 +447,8 @@ class E2eKeysHandler(object): log_kv({"one_time_keys": json_result, "failures": failures}) return {"one_time_keys": json_result, "failures": failures} - @defer.inlineCallbacks @tag_args - def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user(self, user_id, device_id, keys): time_now = self.clock.time_msec() @@ -477,12 +469,12 @@ class E2eKeysHandler(object): } ) # TODO: Sign the JSON with the server key - changed = yield self.store.set_e2e_device_keys( + changed = await self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed - yield self.device_handler.notify_device_update(user_id, [device_id]) + await self.device_handler.notify_device_update(user_id, [device_id]) else: log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) one_time_keys = keys.get("one_time_keys", None) @@ -494,7 +486,7 @@ class E2eKeysHandler(object): "device_id": device_id, } ) - yield self._upload_one_time_keys_for_user( + await self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) else: @@ -507,15 +499,14 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - yield self.device_handler.check_device_registered(user_id, device_id) + await self.device_handler.check_device_registered(user_id, device_id) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) set_tag("one_time_key_counts", result) return {"one_time_key_counts": result} - @defer.inlineCallbacks - def _upload_one_time_keys_for_user( + async def _upload_one_time_keys_for_user( self, user_id, device_id, time_now, one_time_keys ): logger.info( @@ -533,7 +524,7 @@ class E2eKeysHandler(object): key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. 
- existing_key_map = yield self.store.get_e2e_one_time_keys( + existing_key_map = await self.store.get_e2e_one_time_keys( user_id, device_id, [k_id for _, k_id, _ in key_list] ) @@ -556,10 +547,9 @@ class E2eKeysHandler(object): ) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) - yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - @defer.inlineCallbacks - def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user(self, user_id, keys): """Upload signing keys for cross-signing Args: @@ -574,7 +564,7 @@ class E2eKeysHandler(object): _check_cross_signing_key(master_key, user_id, "master") else: - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") # if there is no master key, then we can't do anything, because all the # other cross-signing keys need to be signed by the master key @@ -613,10 +603,10 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) try: @@ -626,23 +616,22 @@ class E2eKeysHandler(object): except ValueError: raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of # their own user-signing key updates - yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + await self.device_handler.notify_user_signature_update(user_id, [user_id]) # master key and self-signing key updates match the semantics of device # list updates: all users who share an encrypted room are notified if len(deviceids): - yield self.device_handler.notify_device_update(user_id, deviceids) + await self.device_handler.notify_device_update(user_id, deviceids) return {} - @defer.inlineCallbacks - def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys(self, user_id, signatures): """Upload device signatures for cross-signing Args: @@ -667,13 +656,13 @@ class E2eKeysHandler(object): self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - self_signature_list, self_failures = yield self._process_self_signatures( + self_signature_list, self_failures = await self._process_self_signatures( user_id, self_signatures ) signature_list.extend(self_signature_list) failures.update(self_failures) - other_signature_list, other_failures = yield self._process_other_signatures( + other_signature_list, other_failures = await self._process_other_signatures( user_id, other_signatures ) signature_list.extend(other_signature_list) @@ -681,21 +670,20 @@ class E2eKeysHandler(object): # store the signature, and send the appropriate notifications for sync logger.debug("upload signature 
failures: %r", failures) - yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + await self.store.store_e2e_cross_signing_signatures(user_id, signature_list) self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: - yield self.device_handler.notify_device_update(user_id, self_device_ids) + await self.device_handler.notify_device_update(user_id, self_device_ids) signed_users = [item.target_user_id for item in other_signature_list] if signed_users: - yield self.device_handler.notify_user_signature_update( + await self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} - @defer.inlineCallbacks - def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -728,7 +716,7 @@ class E2eKeysHandler(object): _, self_signing_key_id, self_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so @@ -738,12 +726,12 @@ class E2eKeysHandler(object): master_key, _, master_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") @@ -853,8 +841,7 @@ class E2eKeysHandler(object): return master_key_signature_list - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures(self, user_id, signatures): """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. @@ -882,7 +869,7 @@ class E2eKeysHandler(object): user_signing_key, user_signing_key_id, user_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -905,7 +892,7 @@ class E2eKeysHandler(object): master_key, master_key_id, _, - ) = yield self._get_e2e_cross_signing_verify_key( + ) = await self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -958,8 +945,7 @@ class E2eKeysHandler(object): return signature_list, failures - @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key( + async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None ): """Fetch locally or remotely query for a cross-signing public key. 
@@ -983,7 +969,7 @@ class E2eKeysHandler(object): SynapseError: if `user_id` is invalid """ user = UserID.from_string(user_id) - key = yield self.store.get_e2e_cross_signing_key( + key = await self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) @@ -1009,15 +995,14 @@ class E2eKeysHandler(object): key, key_id, verify_key, - ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) if key is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) return key, key_id, verify_key - @defer.inlineCallbacks - def _retrieve_cross_signing_keys_for_remote_user( + async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, ): """Queries cross-signing keys for a remote user and saves them to the database @@ -1035,7 +1020,7 @@ class E2eKeysHandler(object): If the key cannot be retrieved, all values in the tuple will instead be None. """ try: - remote_result = yield self.federation.query_user_devices( + remote_result = await self.federation.query_user_devices( user.domain, user.to_string() ) except Exception as e: @@ -1101,14 +1086,14 @@ class E2eKeysHandler(object): desired_key_id = key_id # At the same time, store this key in the db for subsequent queries - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user.to_string(), key_type, key_content ) # Notify clients that new devices for this user have been discovered if retrieved_device_ids: # XXX is this necessary? - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user.to_string(), retrieved_device_ids ) @@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object): iterable=True, ) - @defer.inlineCallbacks - def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update(self, origin, edu_content): """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. @@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object): logger.warning("Got signing key update edu for %r from %r", user_id, origin) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object): (master_key, self_signing_key) ) - yield self._handle_signing_key_updates(user_id) + await self._handle_signing_key_updates(user_id) - @defer.inlineCallbacks - def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id): """Actually handle pending updates. 
Args: @@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object): device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = yield device_list_updater.process_cross_signing_key_update( + new_device_ids = await device_list_updater.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + new_device_ids - yield device_handler.notify_device_update(user_id, device_ids) + await device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f55470a707..0bb983dc28 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( Codes, NotFoundError, @@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object): self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - @defer.inlineCallbacks - def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object): # we deliberately take the lock to get keys so that changing the version # works atomically - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - yield self.store.get_e2e_room_keys_version_info(user_id, version) + await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - results = yield self.store.get_e2e_room_keys( + results = await self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) @@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object): return results @trace - @defer.inlineCallbacks - def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. 
@@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object): """ # lock for consistency with uploading - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object): else: raise - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) version_etag = version_info["etag"] + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @trace - @defer.inlineCallbacks - def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys(self, user_id, version, room_keys): """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). @@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version try: - version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + version_info = await self.store.get_e2e_room_keys_version_info(user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object): if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) # if we get this far, the version must exist @@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object): # submitted. Then compare them with the submitted keys. If the # key is new, insert it; if the key should be updated, then update # it; otherwise, drop it. 
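The hunks above and below all apply the same mechanical conversion: an `@defer.inlineCallbacks` method becomes `async def`, each `yield` on a Deferred becomes `await` (Twisted Deferreds are awaitable, so store methods that still return Deferreds keep working), and `with (yield linearizer.queue(...))` becomes `with (await linearizer.queue(...))`. Callers that are themselves still generator-based bridge back with `defer.ensureDeferred`, which is exactly what the test changes further down do. A minimal sketch of the pattern, using toy classes rather than the real Synapse handlers (all names below are illustrative only):

    from twisted.internet import defer


    class ToyStore:
        async def get_thing(self, key):
            # Stands in for a storage accessor such as get_e2e_room_keys_version_info.
            return {"key": key}


    class ToyHandler:
        def __init__(self, store):
            self.store = store

        # Before this patch the method would have been decorated with
        # @defer.inlineCallbacks and used `yield`; now it is a native coroutine.
        async def get_thing_info(self, key):
            result = await self.store.get_thing(key)
            return result


    @defer.inlineCallbacks
    def legacy_caller():
        # Generator-based code cannot `yield` a coroutine directly, so it wraps
        # the coroutine in a Deferred first, as the updated tests do below.
        handler = ToyHandler(ToyStore())
        info = yield defer.ensureDeferred(handler.get_thing_info("abc"))
        return info
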
- existing_keys = yield self.store.get_e2e_room_keys_multi( + existing_keys = await self.store.get_e2e_room_keys_multi( user_id, version, room_keys["rooms"] ) to_insert = [] # batch the inserts together @@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object): # updates are done one at a time in the DB, so send # updates right away rather than batching them up, # like we do with the inserts - yield self.store.update_e2e_room_key( + await self.store.update_e2e_room_key( user_id, version, room_id, session_id, room_key ) changed = True @@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object): changed = True if len(to_insert): - yield self.store.add_e2e_room_keys(user_id, version, to_insert) + await self.store.add_e2e_room_keys(user_id, version, to_insert) version_etag = version_info["etag"] if changed: version_etag = version_etag + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @staticmethod @@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object): return True @trace - @defer.inlineCallbacks - def create_version(self, user_id, version_info): + async def create_version(self, user_id, version_info): """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. @@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_keys_version( + with (await self._upload_linearizer.queue(user_id)): + new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) return new_version - @defer.inlineCallbacks - def get_version_info(self, user_id, version=None): + async def get_version_info(self, user_id, version=None): """Get the info about a given version of the user's backup Args: @@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object): } """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"]) res["etag"] = str(res["etag"]) return res @trace - @defer.inlineCallbacks - def delete_version(self, user_id, version=None): + async def delete_version(self, user_id, version=None): """Deletes a given version of the user's e2e_room_keys backup Args: @@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object): NotFoundError: if this backup version doesn't exist """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - yield self.store.delete_e2e_room_keys_version(user_id, version) + await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") @@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object): raise @trace - @defer.inlineCallbacks - def 
update_version(self, user_id, version, version_info): + async def update_version(self, user_id, version, version_info): """Update the info about a given version of the user's backup Args: @@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object): raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - old_info = yield self.store.get_e2e_room_keys_version_info( + old_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 1acf287ca4..cdd093ffa8 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ) ) self.fail("No error when changing string key") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ) ) self.fail("No error when replacing dict key with string") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} + yield defer.ensureDeferred( + 
self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ) ) self.fail("No error when replacing string key with dict") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ) ) self.fail("No error when replacing dict key") except errors.SynapseError: @@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) ) self.assertEqual( res2, @@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) keys2 = { "master_key": { @@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys2) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys2) + ) - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield self.handler.upload_keys_for_user( - local_user, "abc", {"device_keys": device_key_1} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) ) - yield self.handler.upload_keys_for_user( - local_user, "def", {"device_keys": device_key_2} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) ) # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) ) # sign the second device key and 
upload both device keys. The server @@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) ) device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] @@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) res = None try: @@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = e.code self.assertEqual(res, 400) - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield self.handler.upload_keys_for_user( - local_user, device_id, {"device_keys": device_key} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) ) # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 @@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + ) # set up another user with a master key. 
This user will be signed by # the first user @@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user( - other_user, {"master_key": other_master_key} + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) ) # test various signature failures (see below) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: { - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - device_id: { - "user_id": local_user, - "device_id": device_id, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, - ], - "keys": { - "curve25519:xyz": "curve25519+key", - # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey, - }, - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because device is unknown - # should fail with NOT_FOUND - "unknown": { - "user_id": local_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - master_pubkey: { - "user_id": local_user, - "usage": ["master"], - "keys": {"ed25519:" + master_pubkey: master_pubkey}, - "signatures": { - local_user: {"ed25519:" + device_pubkey: "something"} + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, }, }, - }, - other_user: { - # fails because the device is not the user's master-signing key - # should fail with NOT_FOUND - "unknown": { - "user_id": other_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, - }, - other_master_pubkey: { - # fails because the key doesn't match what the server has - # should fail with UNKNOWN - "user_id": other_user, - "usage": ["master"], - "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, - "something": "random", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + 
other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, }, }, - }, + ) ) user_failures = ret["failures"][local_user] @@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: {device_id: device_key, master_pubkey: master_key}, - other_user: {other_master_pubkey: other_master_key}, - }, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ret = yield defer.ensureDeferred( + self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) ) self.assertEqual( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 822ea42dde..3362050ce0 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user) + yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. 
""" - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(self.local_user, "1") + res = yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_version(self): """Check that we can update versions. 
""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version, - }, + res = yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) ) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, + yield defer.ensureDeferred( + 
self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.delete_version(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user) + yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) # check that it's gone res = None try: - yield self.handler.get_version_info(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_room_keys(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, @@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys( - self.local_user, "no_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys) ) except errors.SynapseError as e: res = e.code @@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield 
defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) ) except errors.SynapseError as e: res = e.code @@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - version = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(version, "2") res = None try: - yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "1", room_keys) + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 403) @@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, room_keys) @@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": 
"first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) # get the etag to compare to future versions - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here @@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = 
yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") # check for bulk-delete - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys(self.local_user, version) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys(self.local_user, version) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) -- cgit 1.5.1 From f13061d5153eca9bd7054d5b89ade41f3a430f3b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jul 2020 15:27:35 +0100 Subject: Fix client reader sharding tests (#7853) * Fix client reader sharding tests * Newsfile * Fix typing * Update changelog.d/7853.misc Co-authored-by: Patrick Cloke * Move mocking of http_client to tests Co-authored-by: Patrick Cloke --- changelog.d/7853.misc | 1 + synapse/http/client.py | 24 ++- synapse/server.pyi | 5 + tests/replication/_base.py | 168 ++++++++++++++++++- tests/replication/test_client_reader_shard.py | 59 ++----- tests/replication/test_federation_sender_shard.py | 191 ++++++++-------------- tests/server.py | 26 ++- 7 files changed, 300 insertions(+), 174 deletions(-) create mode 100644 changelog.d/7853.misc (limited to 'tests') diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc new file mode 100644 index 0000000000..b4f614084d --- /dev/null +++ b/changelog.d/7853.misc @@ -0,0 +1 @@ +Add support for handling registration requests across multiple 
client reader workers. diff --git a/synapse/http/client.py b/synapse/http/client.py index 505872ee90..b80681135e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -31,6 +31,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, ) +from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody @@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): return False +_EPSILON = 0.00000001 + + +def _make_scheduler(reactor): + """Makes a schedular suitable for a Cooperator using the given reactor. + + (This is effectively just a copy from `twisted.internet.task`) + """ + + def _scheduler(x): + return reactor.callLater(_EPSILON, x) + + return _scheduler + + class IPBlacklistingResolver(object): """ A proxy for reactor.nameResolver which only produces non-blacklisted IP @@ -212,6 +228,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + # We use this for our body producers to ensure that they use the correct + # reactor. + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: @@ -292,7 +312,9 @@ class SimpleHttpClient(object): try: body_producer = None if data is not None: - body_producer = QuieterFileBodyProducer(BytesIO(data)) + body_producer = QuieterFileBodyProducer( + BytesIO(data), cooperator=self._cooperator, + ) request_deferred = treq.request( method, diff --git a/synapse/server.pyi b/synapse/server.pyi index 58cd099e6d..cd50c721b8 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -20,6 +20,7 @@ import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.http.matrixfederationclient import synapse.notifier import synapse.push.pusherpool import synapse.replication.tcp.client @@ -143,3 +144,7 @@ class HomeServer(object): pass def get_replication_streams(self) -> Dict[str, Stream]: pass + def get_http_client( + self, + ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: + pass diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9d4f0bbe44..06575ba0a6 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -14,7 +14,7 @@ # limitations under the License. 
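The change to synapse/http/client.py above builds a Cooperator from a scheduler bound to hs.get_reactor() and hands it to the body producer, so that cooperative work for request bodies is scheduled on the homeserver's own reactor rather than the global default. A rough sketch of the same wiring outside Synapse, using Twisted's stock FileBodyProducer; the names make_scheduler, make_body_producer and my_reactor are illustrative only, not part of the patch:

    from io import BytesIO

    from twisted.internet import task
    from twisted.web.client import FileBodyProducer

    _EPSILON = 0.00000001


    def make_scheduler(my_reactor):
        # Schedule each chunk of cooperative work as a near-immediate callLater
        # on the chosen reactor, mirroring the helper added in the hunk above.
        def _scheduler(f):
            return my_reactor.callLater(_EPSILON, f)

        return _scheduler


    def make_body_producer(data: bytes, my_reactor) -> FileBodyProducer:
        # A Cooperator instance exposes cooperate(), which is all that
        # FileBodyProducer needs from its `cooperator` argument.
        cooperator = task.Cooperator(scheduler=make_scheduler(my_reactor))
        return FileBodyProducer(BytesIO(data), cooperator=cooperator)
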
import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import attr @@ -26,8 +26,9 @@ from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, GenericWorkerServer, ) +from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.http import streams +from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -35,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport +from tests.server import FakeTransport, render logger = logging.getLogger(__name__) @@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") +class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests running multiple workers. + + Automatically handle HTTP replication requests from workers to master, + unlike `BaseStreamTestCase`. + """ + + servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + + def setUp(self): + super().setUp() + + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = self.hs.get_replication_streamer() + + store = self.hs.get_datastore() + self.database = store.db + + self.reactor.lookups["testserv"] = "1.2.3.4" + + self._worker_hs_to_resource = {} + + # When we see a connection attempt to the master replication listener we + # automatically set up the connection. This is so that tests don't + # manually have to go and explicitly set it up each time (plus sometimes + # it is impossible to write the handling explicitly in the tests). + self.reactor.add_tcp_client_callback( + "1.2.3.4", 8765, self._handle_http_replication_attempt + ) + + def create_test_json_resource(self): + """Overrides `HomeserverTestCase.create_test_json_resource`. + """ + # We override this so that it automatically registers all the HTTP + # replication servlets, without having to explicitly do that in all + # subclassses. + + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + + def make_worker_hs( + self, worker_app: str, extra_config: dict = {}, **kwargs + ) -> HomeServer: + """Make a new worker HS instance, correctly connecting replcation + stream to the master HS. + + Args: + worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + extra_config: Any extra config to use for this instances. + **kwargs: Options that get passed to `self.setup_test_homeserver`, + useful to e.g. pass some mocks for things like `http_client` + + Returns: + The new worker HomeServer instance. 
+ """ + + config = self._get_worker_hs_config() + config["worker_app"] = worker_app + config.update(extra_config) + + worker_hs = self.setup_test_homeserver( + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + **kwargs + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + # Set up a resource for the worker + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(worker_hs, resource) + + self._worker_hs_to_resource[worker_hs] = resource + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): + render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump() + + def _handle_http_replication_attempt(self): + """Handles a connection attempt to the master replication HTTP + listener. + """ + + # We should have at least one outbound connection attempt, where the + # last is one to the HTTP repication IP/port. + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # Note: at this point we've wired everything up, but we need to return + # before the data starts flowing over the connections as this is called + # inside `connecTCP` before the connection has been passed back to the + # code that requested the TCP connection. + + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel): # We need to manually stop the _PullToPushProducer. self._pull_to_push_producer.stop() + def checkPersistence(self, request, version): + """Check whether the connection can be re-used + """ + # We hijack this to always say no for ease of wiring stuff up in + # `handle_http_replication_attempt`. + request.responseHeaders.setRawHeaders(b"connection", [b"close"]) + return False + class _PullToPushProducer: """A push producer that wraps a pull producer. 
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b7d753e0a3..86c03fd89c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -15,63 +15,26 @@ import logging from synapse.api.constants import LoginType -from synapse.app.generic_worker import GenericWorkerServer -from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest.client.v2_alpha import register -from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker -from tests.server import FakeChannel, render +from tests.server import FakeChannel logger = logging.getLogger(__name__) -class ClientReaderTestCase(unittest.HomeserverTestCase): +class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): """Base class for tests of the replication streams""" - servlets = [ - register.register_servlets, - ] + servlets = [register.register_servlets] def prepare(self, reactor, clock, hs): - # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(hs) - self.streamer = hs.get_replication_streamer() - - store = hs.get_datastore() - self.database = store.db - self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - self.reactor.lookups["testserv"] = "1.2.3.4" - - def make_worker_hs(self, extra_config={}): - config = self._get_worker_hs_config() - config.update(extra_config) - - worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor, - ) - - store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool - - # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource. - resource = JsonResource(self.hs) - - for servlet in self.servlets: - servlet(worker_hs, resource) - - # Essentially HomeserverTestCase.render. - def _render(request): - render(request, self.resource, self.reactor) - - return worker_hs, _render - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" @@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): def test_register_single_worker(self): """Test that registration works when using a single client reader worker. """ - _, worker_render = self.make_worker_hs() + worker_hs = self.make_worker_hs("synapse.app.client_reader") request_1, channel_1 = self.make_request( "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - worker_render(request_1) + self.render_on_worker(worker_hs, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): request_2, channel_2 = self.make_request( "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} ) # type: SynapseRequest, FakeChannel - worker_render(request_2) + self.render_on_worker(worker_hs, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. 
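The sharding tests in this file exercise the standard two-step user-interactive-auth exchange on /register: the first request is answered with a 401 carrying a session identifier, and a follow-up request completing the m.login.dummy stage with that session finishes the registration; the point of the multi-worker variant is that the second request may be served by a different client_reader. The same exchange written as plain HTTP, for illustration only (it assumes a locally running homeserver that offers the dummy flow, the third-party requests library, and the usual /_matrix/client/r0/register path; none of these are defined by this patch):

    import requests

    url = "http://localhost:8008/_matrix/client/r0/register"

    # Step 1: no auth yet, so the server replies 401 with a UIA session.
    r1 = requests.post(
        url, json={"username": "user", "type": "m.login.password", "password": "bar"}
    )
    assert r1.status_code == 401
    session = r1.json()["session"]

    # Step 2: complete the dummy stage against the same session. In the sharded
    # test this request is handled by a different worker, which only works if
    # the UIA session state is shared between the workers.
    r2 = requests.post(
        url, json={"auth": {"session": session, "type": "m.login.dummy"}}
    )
    assert r2.status_code == 200
    print(r2.json()["user_id"])
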
@@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): def test_register_multi_worker(self): """Test that registration works when using multiple client reader workers. """ - _, worker_render_1 = self.make_worker_hs() - _, worker_render_2 = self.make_worker_hs() + worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") + worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") request_1, channel_1 = self.make_request( "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - worker_render_1(request_1) + self.render_on_worker(worker_hs_1, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): request_2, channel_2 = self.make_request( "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} ) # type: SynapseRequest, FakeChannel - worker_render_2(request_2) + self.render_on_worker(worker_hs_2, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 519a2dc510..8d4dbf232e 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -19,132 +19,40 @@ from mock import Mock from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.app.generic_worker import GenericWorkerServer from synapse.events.builder import EventBuilderFactory -from synapse.replication.http import streams -from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import UserID -from tests import unittest -from tests.server import FakeTransport +from tests.replication._base import BaseMultiWorkerStreamTestCase logger = logging.getLogger(__name__) -class BaseStreamTestCase(unittest.HomeserverTestCase): - """Base class for tests of the replication streams""" - +class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): servlets = [ - streams.register_servlets, + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, ] - def prepare(self, reactor, clock, hs): - # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(hs) - self.streamer = hs.get_replication_streamer() - - store = hs.get_datastore() - self.database = store.db - - self.reactor.lookups["testserv"] = "1.2.3.4" - def default_config(self): conf = super().default_config() conf["send_federation"] = False return conf - def make_worker_hs(self, extra_config={}): - config = self._get_worker_hs_config() - config.update(extra_config) - - mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) - - worker_hs = self.setup_test_homeserver( - http_client=mock_federation_client, - homeserverToUse=GenericWorkerServer, - config=config, - reactor=self.reactor, - ) - - store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool - - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", 
self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) - - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) - - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) - - return worker_hs - - def _get_worker_hs_config(self) -> dict: - config = self.default_config() - config["worker_app"] = "synapse.app.federation_sender" - config["worker_replication_host"] = "testserv" - config["worker_replication_http_port"] = "8765" - return config - - def replicate(self): - """Tell the master side of replication that something has happened, and then - wait for the replication to occur. - """ - self.streamer.on_notifier_poke() - self.pump() - - def create_room_with_remote_server(self, user, token, remote_server="other_server"): - room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() - federation = self.hs.get_handlers().federation_handler - - prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) - room_version = self.get_success(store.get_room_version(room)) - - factory = EventBuilderFactory(self.hs) - factory.hostname = remote_server - - user_id = UserID("user", remote_server).to_string() - - event_dict = { - "type": EventTypes.Member, - "state_key": user_id, - "content": {"membership": Membership.JOIN}, - "sender": user_id, - "room_id": room, - } - - builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success(builder.build(prev_event_ids)) - - self.get_success(federation.on_send_join_request(remote_server, join_event)) - self.replicate() - - return room - - -class FederationSenderTestCase(BaseStreamTestCase): - servlets = [ - login.register_servlets, - register_servlets_for_client_rest_resource, - room.register_servlets, - ] - def test_send_event_single_sender(self): """Test that using a single federation sender worker correctly sends a new event. """ - worker_hs = self.make_worker_hs({"send_federation": True}) - mock_client = worker_hs.get_http_client() + mock_client = Mock(spec=["put_json"]) + mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + + self.make_worker_hs( + "synapse.app.federation_sender", + {"send_federation": True}, + http_client=mock_client, + ) user = self.register_user("user", "pass") token = self.login("user", "pass") @@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase): """Test that using two federation sender workers correctly sends new events. 
""" - worker1 = self.make_worker_hs( + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client1, ) - mock_client1 = worker1.get_http_client() - worker2 = self.make_worker_hs( + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client2, ) - mock_client2 = worker2.get_http_client() user = self.register_user("user2", "pass") token = self.login("user2", "pass") @@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() - mock_client2.reset_mock() + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] self.create_and_send_event(room, UserID.from_string(user)) self.replicate() @@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase): """Test that using two federation sender workers correctly sends new typing EDUs. """ - worker1 = self.make_worker_hs( + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client1, ) - mock_client1 = worker1.get_http_client() - worker2 = self.make_worker_hs( + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client2, ) - mock_client2 = worker2.get_http_client() user = self.register_user("user3", "pass") token = self.login("user3", "pass") @@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() - mock_client2.reset_mock() + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] self.get_success( typing_handler.started_typing( @@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase): self.assertTrue(sent_on_1) self.assertTrue(sent_on_2) + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + 
"room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room diff --git a/tests/server.py b/tests/server.py index a5e57c52fa..b6e0b14e78 100644 --- a/tests/server.py +++ b/tests/server.py @@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) + self._tcp_callbacks = {} self._udp = [] lookups = self.lookups = {} @@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool + def add_tcp_client_callback(self, host, port, callback): + """Add a callback that will be invoked when we receive a connection + attempt to the given IP/port using `connectTCP`. + + Note that the callback gets run before we return the connection to the + client, which means callbacks cannot block while waiting for writes. + """ + self._tcp_callbacks[(host, port)] = callback + + def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + """Fake L{IReactorTCP.connectTCP}. + """ + + conn = super().connectTCP( + host, port, factory, timeout=timeout, bindAddress=None + ) + + callback = self._tcp_callbacks.get((host, port)) + if callback: + callback() + + return conn + class ThreadPool: """ @@ -486,7 +510,7 @@ class FakeTransport(object): try: self.other.dataReceived(to_write) except Exception as e: - logger.warning("Exception writing to protocol: %s", e) + logger.exception("Exception writing to protocol: %s", e) return self.buffer = self.buffer[len(to_write) :] -- cgit 1.5.1 From 9006e125afa1de199577f79025913e7ad8ae9701 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jul 2020 15:47:27 +0100 Subject: Fix tests --- tests/handlers/test_typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1e6a53bf7f..5878f74175 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - def get_current_users_in_room(room_id): + def get_users_in_room(room_id): return defer.succeed({str(u) for u in self.room_members}) - hs.get_state_handler().get_current_users_in_room = get_current_users_in_room + self.datastore.get_users_in_room = get_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam -- cgit 1.5.1 From 8c7d0f163d8247297dbcfd5f257b652ebe417fff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Jul 2020 11:00:21 -0400 Subject: Allow accounts to be re-activated from the admin APIs. 
(#7847) --- changelog.d/7847.feature | 1 + docs/admin_api/user_admin_api.rst | 6 ++++- synapse/handlers/deactivate_account.py | 48 ++++++++++++++++++++-------------- synapse/rest/admin/users.py | 10 ++++++- tests/rest/admin/test_user.py | 47 +++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7847.feature (limited to 'tests') diff --git a/changelog.d/7847.feature b/changelog.d/7847.feature new file mode 100644 index 0000000000..4b9a8d8569 --- /dev/null +++ b/changelog.d/7847.feature @@ -0,0 +1 @@ +Add the ability to re-activate an account from the admin API. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 7b030a6285..be05128b3e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -91,10 +91,14 @@ Body parameters: - ``admin``, optional, defaults to ``false``. -- ``deactivated``, optional, defaults to ``false``. +- ``deactivated``, optional. If unspecified, deactivation state will be left + unchanged on existing accounts and set to ``false`` for new accounts. If the user already exists then optional parameters default to the current value. +In order to re-activate an account ``deactivated`` must be set to ``false``. If +users do not login via single-sign-on, a new ``password`` must be provided. + List Accounts ============= diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 3e3e6bd475..696d85b5f9 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process @@ -45,19 +46,20 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled - async def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account( + self, user_id: str, erase_data: bool, id_server: Optional[str] = None + ) -> bool: """Deactivate a user's account Args: - user_id (str): ID of user to be deactivated - erase_data (bool): whether to GDPR-erase the user's data - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to be deactivated + erase_data: whether to GDPR-erase the user's data + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). Returns: - Deferred[bool]: True if identity server supports removing - threepids, otherwise False. + True if identity server supports removing threepids, otherwise False. """ # FIXME: Theoretically there is a race here wherein user resets # password using threepid. @@ -134,11 +136,11 @@ class DeactivateAccountHandler(BaseHandler): return identity_server_supports_unbinding - async def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id: str): """Reject pending invites addressed to a given user ID. Args: - user_id (str): The user ID to reject pending invites for. + user_id: The user ID to reject pending invites for. 
""" user = UserID.from_string(user_id) pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) @@ -166,22 +168,16 @@ class DeactivateAccountHandler(BaseHandler): room.room_id, ) - def _start_user_parting(self): + def _start_user_parting(self) -> None: """ Start the process that goes through the table of users pending deactivation, if it isn't already running. - - Returns: - None """ if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - async def _user_parter_loop(self): + async def _user_parter_loop(self) -> None: """Loop that parts deactivated users from rooms - - Returns: - None """ self._user_parter_running = True logger.info("Starting user parter") @@ -198,11 +194,8 @@ class DeactivateAccountHandler(BaseHandler): finally: self._user_parter_running = False - async def _part_user(self, user_id): + async def _part_user(self, user_id: str) -> None: """Causes the given user_id to leave all the rooms they're joined to - - Returns: - None """ user = UserID.from_string(user_id) @@ -224,3 +217,18 @@ class DeactivateAccountHandler(BaseHandler): user_id, room_id, ) + + async def activate_account(self, user_id: str) -> None: + """ + Activate an account that was previously deactivated. + + This simply marks the user as activate in the database and does not + attempt to rejoin rooms, re-add threepids, etc. + + The user will also need a password hash set to actually login. + + Args: + user_id: ID of user to be deactivated + """ + # Mark the user as activate. + await self.store.set_user_deactivated_status(user_id, False) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e4330c39d6..cc0bdfa5c9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet): await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False ) + elif not deactivate and user["deactivated"]: + if "password" not in body: + raise SynapseError( + 400, "Must provide a password to re-activate an account." + ) + + await self.deactivate_account_handler.activate_account( + target_user.to_string() + ) user = await self.admin_handler.get_user(target_user) return 200, user @@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet): admin = body.get("admin", None) user_type = body.get("user_type", None) displayname = body.get("displayname", None) - threepids = body.get("threepids", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cca5f548e6..f16eef15f7 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + def test_reactivate_user(self): + """ + Test reactivating another user. + """ + + # Deactivate the user. + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Attempt to reactivate the user (without a password). 
+ request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": False, "password": "foo"}).encode( + encoding="utf_8" + ), + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + def test_set_user_as_admin(self): """ Test setting the admin flag on a user. -- cgit 1.5.1 From 649a7ead5c4bd2d8b7c486ac1a68ce4e41d49767 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jul 2020 14:06:28 +0100 Subject: Add ability to run multiple pusher instances (#7855) This reuses the same scheme as federation sender sharding --- changelog.d/7855.feature | 1 + synapse/config/_base.py | 38 +++- synapse/config/_base.pyi | 5 + synapse/config/federation.py | 37 +--- synapse/config/push.py | 5 +- synapse/federation/sender/__init__.py | 16 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/push/pusherpool.py | 78 +++++---- tests/replication/test_pusher_shard.py | 193 +++++++++++++++++++++ 9 files changed, 293 insertions(+), 82 deletions(-) create mode 100644 changelog.d/7855.feature create mode 100644 tests/replication/test_pusher_shard.py (limited to 'tests') diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature new file mode 100644 index 0000000000..2b6a9f0e71 --- /dev/null +++ b/changelog.d/7855.feature @@ -0,0 +1 @@ +Add experimental support for running multiple pusher workers. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1391e5fc43..fd137853b1 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -19,9 +19,11 @@ import argparse import errno import os from collections import OrderedDict +from hashlib import sha256 from textwrap import dedent -from typing import Any, MutableMapping, Optional +from typing import Any, List, MutableMapping, Optional +import attr import yaml @@ -717,4 +719,36 @@ def find_config_files(search_paths): return config_files -__all__ = ["Config", "RootConfig"] +@attr.s +class ShardedWorkerHandlingConfig: + """Algorithm for choosing which instance is responsible for handling some + sharded work. + + For example, the federation senders use this to determine which instances + handles sending stuff to a given destination (which is used as the `key` + below). + """ + + instances = attr.ib(type=List[str]) + + def should_handle(self, instance_name: str, key: str) -> bool: + """Whether this instance is responsible for handling the given key. + """ + + # If multiple instances are not defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash, modulo it by the number of instances and + # then checking whether this instance matches the instance at that + # index. 
+ # + # (Technically this introduces some bias and is not entirely uniform, + # but since the hash is so large the bias is ridiculously small). + dest_hash = sha256(key.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 9e576060d4..eb911e8f9f 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -137,3 +137,8 @@ class Config: def read_config_files(config_files: List[str]): ... def find_config_files(search_paths: List[str]): ... + +class ShardedWorkerHandlingConfig: + instances: List[str] + def __init__(self, instances: List[str]) -> None: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 7782ab4c9d..82ff9664de 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -13,42 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hashlib import sha256 -from typing import List, Optional +from typing import Optional -import attr from netaddr import IPSet -from ._base import Config, ConfigError - - -@attr.s -class ShardedFederationSendingConfig: - """Algorithm for choosing which federation sender instance is responsible - for which destionation host. - """ - - instances = attr.ib(type=List[str]) - - def should_send_to(self, instance_name: str, destination: str) -> bool: - """Whether this instance is responsible for sending transcations for - the given host. - """ - - # If multiple federation senders are not defined we always return true. - if not self.instances or len(self.instances) == 1: - return True - - # We shard by taking the hash, modulo it by the number of federation - # senders and then checking whether this instance matches the instance - # at that index. - # - # (Technically this introduces some bias and is not entirely uniform, but - # since the hash is so large the bias is ridiculously small). - dest_hash = sha256(destination.encode("utf8")).digest() - dest_int = int.from_bytes(dest_hash, byteorder="little") - remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name +from ._base import Config, ConfigError, ShardedWorkerHandlingConfig class FederationConfig(Config): @@ -61,7 +30,7 @@ class FederationConfig(Config): self.send_federation = config.get("send_federation", True) federation_sender_instances = config.get("federation_sender_instances") or [] - self.federation_shard_config = ShardedFederationSendingConfig( + self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances ) diff --git a/synapse/config/push.py b/synapse/config/push.py index 6f2b3a7faa..a1f3752c8a 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
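For reference, the sharding rule implemented by `ShardedWorkerHandlingConfig.should_handle` above can be reproduced standalone. A minimal sketch, using a hypothetical `pick_instance` helper and a two-instance list (illustrative only, not part of the patch):

from hashlib import sha256


def pick_instance(instances, key):
    # Hash the key, reduce it modulo the number of instances, and use the
    # result as an index into the (ordered) instance list.
    dest_hash = sha256(key.encode("utf8")).digest()
    dest_int = int.from_bytes(dest_hash, byteorder="little")
    return instances[dest_int % len(instances)]


instances = ["sender1", "sender2"]
chosen = pick_instance(instances, "remote.example.com")

# should_handle(name, key) is then just pick_instance(instances, key) == name,
# with the special case that an empty or single-entry list always returns True.
assert chosen in instances
assert sum(pick_instance(instances, "remote.example.com") == i for i in instances) == 1

Because the choice depends only on the key and the ordered instance list, every worker computes the same answer independently, with no coordination between instances.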
-from ._base import Config +from ._base import Config, ShardedWorkerHandlingConfig class PushConfig(Config): @@ -24,6 +24,9 @@ class PushConfig(Config): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) + pusher_instances = config.get("pusher_instances") or [] + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) + # There was a a 'redact_content' setting but mistakenly read from the # 'email'section'. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4b63a0755f..b328a4df09 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -197,7 +197,7 @@ class FederationSender(object): destinations = { d for d in destinations - if self._federation_shard_config.should_send_to( + if self._federation_shard_config.should_handle( self._instance_name, d ) } @@ -335,7 +335,7 @@ class FederationSender(object): d for d in domains if d != self.server_name - and self._federation_shard_config.should_send_to(self._instance_name, d) + and self._federation_shard_config.should_handle(self._instance_name, d) ] if not domains: return @@ -441,7 +441,7 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): continue @@ -460,7 +460,7 @@ class FederationSender(object): if destination == self.server_name: continue - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): continue @@ -486,7 +486,7 @@ class FederationSender(object): logger.info("Not sending EDU to ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return @@ -507,7 +507,7 @@ class FederationSender(object): edu: edu to send key: clobbering key for this edu """ - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, edu.destination ): return @@ -523,7 +523,7 @@ class FederationSender(object): logger.warning("Not sending device update to ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return @@ -541,7 +541,7 @@ class FederationSender(object): logger.warning("Not waking up ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 6402136e8a..3436741783 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -78,7 +78,7 @@ class PerDestinationQueue(object): self._federation_shard_config = hs.config.federation.federation_shard_config self._should_send_on_this_instance = True - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): # We don't raise an exception here to avoid taking out any other diff --git 
a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index f6a5458681..2456f12f46 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,13 +15,12 @@ # limitations under the License. import logging -from collections import defaultdict -from threading import Lock -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Dict, Union + +from prometheus_client import Gauge from twisted.internet import defer -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher from synapse.push.pusher import PusherFactory from synapse.util.async_helpers import concurrently_execute +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) +synapse_pushers = Gauge( + "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"] +) + + class PusherPool: """ The pusher pool. This is responsible for dispatching notifications of new events to @@ -47,36 +55,20 @@ class PusherPool: Pusher.on_new_receipts are not expected to return deferreds. """ - def __init__(self, _hs): - self.hs = _hs - self.pusher_factory = PusherFactory(_hs) - self._should_start_pushers = _hs.config.start_pushers + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.pusher_factory = PusherFactory(hs) + self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + # We shard the handling of push notifications by user ID. + self._pusher_shard_config = hs.config.push.pusher_shard_config + self._instance_name = hs.get_instance_name() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] - # a lock for the pushers dict, since `count_pushers` is called from an different - # and we otherwise get concurrent modification errors - self._pushers_lock = Lock() - - def count_pushers(): - results = defaultdict(int) # type: Dict[Tuple[str, str], int] - with self._pushers_lock: - for pushers in self.pushers.values(): - for pusher in pushers.values(): - k = (type(pusher).__name__, pusher.app_id) - results[k] += 1 - return results - - LaterGauge( - name="synapse_pushers", - desc="the number of active pushers", - labels=["kind", "app_id"], - caller=count_pushers, - ) - def start(self): """Starts the pushers off in a background process. 
""" @@ -104,6 +96,7 @@ class PusherPool: Returns: Deferred[EmailPusher|HttpPusher] """ + time_now_msec = self.clock.time_msec() # we try to create the pusher just to validate the config: it @@ -176,6 +169,9 @@ class PusherPool: access_tokens (Iterable[int]): access token *ids* to remove pushers for """ + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): if p["access_token"] in tokens: @@ -237,6 +233,9 @@ class PusherPool: if not self._should_start_pushers: return + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None @@ -275,6 +274,11 @@ class PusherPool: Returns: Deferred[EmailPusher|HttpPusher] """ + if not self._pusher_shard_config.should_handle( + self._instance_name, pusherdict["user_name"] + ): + return + try: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: @@ -298,11 +302,12 @@ class PusherPool: appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) - with self._pushers_lock: - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + + synapse_pushers.labels(type(p).__name__, p.app_id).inc() # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to @@ -330,9 +335,10 @@ class PusherPool: if appid_pushkey in byuser: logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - byuser[appid_pushkey].on_stop() - with self._pushers_lock: - del byuser[appid_pushkey] + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py new file mode 100644 index 0000000000..2bdc6edbb1 --- /dev/null +++ b/tests/replication/test_pusher_shard.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + + +class PusherShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks pusher sharding works + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def default_config(self): + conf = super().default_config() + conf["start_pushers"] = False + return conf + + def _create_pusher_and_send_msg(self, localpart): + # Create a user that will get push notifications + user_id = self.register_user(localpart, "pass") + access_token = self.login(localpart, "pass") + + # Register a pusher + user_dict = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_dict["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "https://push.example.com/push"}, + ) + ) + + self.pump() + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + response = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + return event_id + + def test_send_push_single_worker(self): + """Test that registration works when using a pusher worker. + """ + http_client_mock = Mock(spec_set=["post_json_get_json"]) + http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + {"start_pushers": True}, + proxied_http_client=http_client_mock, + ) + + event_id = self._create_pusher_and_send_msg("user") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + def test_send_push_multiple_workers(self): + """Test that registration works when using sharded pusher workers. 
+ """ + http_client_mock1 = Mock(spec_set=["post_json_get_json"]) + http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher1", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock1, + ) + + http_client_mock2 = Mock(spec_set=["post_json_get_json"]) + http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher2", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock2, + ) + + # We choose a user name that we know should go to pusher1. + event_id = self._create_pusher_and_send_msg("user2") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_called_once() + http_client_mock2.post_json_get_json.assert_not_called() + self.assertEqual( + http_client_mock1.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock1.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + http_client_mock1.post_json_get_json.reset_mock() + http_client_mock2.post_json_get_json.reset_mock() + + # Now we choose a user name that we know should go to pusher2. + event_id = self._create_pusher_and_send_msg("user4") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_not_called() + http_client_mock2.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock2.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock2.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) -- cgit 1.5.1 From fff483ea96160912ee4e9f5f3f743b86a933058f Mon Sep 17 00:00:00 2001 From: Michael Albert <37796947+awesome-michael@users.noreply.github.com> Date: Thu, 16 Jul 2020 22:43:23 +0200 Subject: Add admin endpoint to get members in a room. (#7842) --- changelog.d/7842.feature | 1 + docs/admin_api/rooms.md | 34 ++++++++++++++++++++++++++++++- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 25 +++++++++++++++++++++++ tests/rest/admin/test_room.py | 46 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7842.feature (limited to 'tests') diff --git a/changelog.d/7842.feature b/changelog.d/7842.feature new file mode 100644 index 0000000000..727deb01c9 --- /dev/null +++ b/changelog.d/7842.feature @@ -0,0 +1 @@ +Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 3f26adc16c..15b83e9824 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -319,11 +319,43 @@ Response: } ``` +# Room Members API + +The Room Members admin API allows server admins to get a list of all members of a room. + +The response includes the following fields: + +* `members` - A list of all the members that are present in the room, represented by their ids. +* `total` - Total number of members in the room. 
+ +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms//members + +{} +``` + +Response: + +``` +{ + "members": [ + "@foo:matrix.org", + "@bar:matrix.org", + "@foobar:matrix.org + ], + "total": 3 +} +``` + # Delete Room API The Delete Room admin API allows server admins to remove rooms from server and block these rooms. -It is a combination and improvement of "[Shutdown room](shutdown_room.md)" +It is a combination and improvement of "[Shutdown room](shutdown_room.md)" and "[Purge room](purge_room.md)" API. Shuts down a room. Moves all local users and room aliases automatically to a diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index dc373bc5a3..1c88c93f38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, ) @@ -201,6 +202,7 @@ def register_servlets(hs, http_server): register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 544be47060..b8c95d045a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -231,6 +231,31 @@ class RoomRestServlet(RestServlet): return 200, ret +class RoomMembersRestServlet(RestServlet): + """ + Get members list of a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + members = await self.store.get_users_in_room(room_id) + ret = {"members": members, "total": len(members)} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a80537c4fc..946f06d151 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1136,6 +1136,52 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) + def test_room_members(self): + """Test that room members can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + # Have another user join the room + user_2 = self.register_user("bar", "pass") + user_tok_2 = self.login("bar", "pass") + self.helper.join(room_id_1, user_2, tok=user_tok_2) + self.helper.join(room_id_2, user_2, tok=user_tok_2) + + # Have another user join the room + user_3 = self.register_user("foobar", "pass") + user_tok_3 = self.login("foobar", "pass") + self.helper.join(room_id_2, user_3, tok=user_tok_3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) + 
request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 6fca1b3506e31e6864e1dc18046f1962813f14e2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Jul 2020 07:08:30 -0400 Subject: Convert _base, profile, and _receipts handlers to async/await (#7860) --- changelog.d/7860.misc | 1 + synapse/handlers/_base.py | 7 ++--- synapse/handlers/message.py | 8 ++++-- synapse/handlers/profile.py | 63 ++++++++++++++++++------------------------ synapse/handlers/receipts.py | 16 ++++------- tests/handlers/test_profile.py | 17 ++++++++---- 6 files changed, 53 insertions(+), 59 deletions(-) create mode 100644 changelog.d/7860.misc (limited to 'tests') diff --git a/changelog.d/7860.misc b/changelog.d/7860.misc new file mode 100644 index 0000000000..fdd48b955c --- /dev/null +++ b/changelog.d/7860.misc @@ -0,0 +1 @@ +Convert _base, profile, and _receipts handlers to async/await. diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a4944467a..ba2bf99800 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - import synapse.state import synapse.storage import synapse.types @@ -66,8 +64,7 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() - @defer.inlineCallbacks - def ratelimit(self, requester, update=True, is_admin_redaction=False): + async def ratelimit(self, requester, update=True, is_admin_redaction=False): """Ratelimits requests. Args: @@ -99,7 +96,7 @@ class BaseHandler(object): burst_count = self._rc_message.burst_count # Check if there is a per user override in the DB. 
- override = yield self.store.get_ratelimit_for_user(user_id) + override = await self.store.get_ratelimit_for_user(user_id) if override: # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da206e1ec1..c47764a4ce 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -488,11 +488,15 @@ class EventCreationHandler(object): try: if "displayname" not in content: - displayname = yield profile.get_displayname(target) + displayname = yield defer.ensureDeferred( + profile.get_displayname(target) + ) if displayname is not None: content["displayname"] = displayname if "avatar_url" not in content: - avatar_url = yield profile.get_avatar_url(target) + avatar_url = yield defer.ensureDeferred( + profile.get_avatar_url(target) + ) if avatar_url is not None: content["avatar_url"] = avatar_url except Exception as e: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 4b1e3073a8..31a2e5ea18 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -54,16 +52,15 @@ class BaseProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def get_profile(self, user_id): + async def get_profile(self, user_id): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -74,7 +71,7 @@ class BaseProfileHandler(BaseHandler): return {"displayname": displayname, "avatar_url": avatar_url} else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": user_id}, @@ -86,8 +83,7 @@ class BaseProfileHandler(BaseHandler): except HttpResponseException as e: raise e.to_synapse_error() - @defer.inlineCallbacks - def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id): """Get the profile information from our local cache. If the user is ours then the profile information will always be corect. Otherwise, it may be out of date/missing. 
@@ -95,10 +91,10 @@ class BaseProfileHandler(BaseHandler): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -108,14 +104,13 @@ class BaseProfileHandler(BaseHandler): return {"displayname": displayname, "avatar_url": avatar_url} else: - profile = yield self.store.get_from_remote_profile_cache(user_id) + profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - @defer.inlineCallbacks - def get_displayname(self, target_user): + async def get_displayname(self, target_user): if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) except StoreError as e: @@ -126,7 +121,7 @@ class BaseProfileHandler(BaseHandler): return displayname else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "displayname"}, @@ -189,11 +184,10 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - @defer.inlineCallbacks - def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): try: - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -203,7 +197,7 @@ class BaseProfileHandler(BaseHandler): return avatar_url else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "avatar_url"}, @@ -253,8 +247,7 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - @defer.inlineCallbacks - def on_profile_query(self, args): + async def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -264,12 +257,12 @@ class BaseProfileHandler(BaseHandler): response = {} try: if just_field is None or just_field == "displayname": - response["displayname"] = yield self.store.get_profile_displayname( + response["displayname"] = await self.store.get_profile_displayname( user.localpart ) if just_field is None or just_field == "avatar_url": - response["avatar_url"] = yield self.store.get_profile_avatar_url( + response["avatar_url"] = await self.store.get_profile_avatar_url( user.localpart ) except StoreError as e: @@ -304,8 +297,7 @@ class BaseProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e) ) - @defer.inlineCallbacks - def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed(self, target_user, requester=None): """Checks whether a profile query is allowed. 
If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users @@ -337,8 +329,8 @@ class BaseProfileHandler(BaseHandler): return try: - requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) - target_user_rooms = yield self.store.get_rooms_for_user( + requester_rooms = await self.store.get_rooms_for_user(requester.to_string()) + target_user_rooms = await self.store.get_rooms_for_user( target_user.to_string() ) @@ -371,25 +363,24 @@ class MasterProfileHandler(BaseProfileHandler): "Update remote profile", self._update_remote_profile_cache ) - @defer.inlineCallbacks - def _update_remote_profile_cache(self): + async def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. """ - entries = yield self.store.get_remote_profile_cache_entries_that_expire( + entries = await self.store.get_remote_profile_cache_entries_that_expire( last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS ) for user_id, displayname, avatar_url in entries: - is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + is_subscribed = await self.store.is_subscribed_remote_profile_for_user( user_id ) if not is_subscribed: - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) continue try: - profile = yield self.federation.make_query( + profile = await self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", args={"user_id": user_id}, @@ -398,7 +389,7 @@ class MasterProfileHandler(BaseProfileHandler): except Exception: logger.exception("Failed to get avatar_url") - yield self.store.update_remote_profile_cache( + await self.store.update_remote_profile_cache( user_id, displayname, avatar_url ) continue @@ -407,4 +398,4 @@ class MasterProfileHandler(BaseProfileHandler): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) + await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 8bc100db42..f922d8a545 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,8 +14,6 @@ # limitations under the License. 
import logging -from twisted.internet import defer - from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id from synapse.util.async_helpers import maybe_awaitable @@ -129,15 +127,14 @@ class ReceiptEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) - to_key = yield self.get_current_key() + to_key = self.get_current_key() if from_key == to_key: return [], to_key - events = yield self.store.get_linearized_receipts_for_rooms( + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) @@ -146,8 +143,7 @@ class ReceiptEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): to_key = int(config.from_key) if config.to_key: @@ -155,8 +151,8 @@ class ReceiptEventSource(object): else: from_key = None - room_ids = yield self.store.get_rooms_for_user(user.to_string()) - events = yield self.store.get_linearized_receipts_for_rooms( + room_ids = await self.store.get_rooms_for_user(user.to_string()) + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 29dd7d9c6e..4f1347cd25 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -72,7 +72,9 @@ class ProfileTestCase(unittest.TestCase): def test_get_my_name(self): yield self.store.set_profile_displayname(self.frank.localpart, "Frank") - displayname = yield self.handler.get_displayname(self.frank) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.frank) + ) self.assertEquals("Frank", displayname) @@ -140,7 +142,9 @@ class ProfileTestCase(unittest.TestCase): {"displayname": "Alice"} ) - displayname = yield self.handler.get_displayname(self.alice) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.alice) + ) self.assertEquals(displayname, "Alice") self.mock_federation.make_query.assert_called_with( @@ -155,8 +159,10 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile("caroline") yield self.store.set_profile_displayname("caroline", "Caroline") - response = yield self.query_handlers["profile"]( - {"user_id": "@caroline:test", "field": "displayname"} + response = yield defer.ensureDeferred( + self.query_handlers["profile"]( + {"user_id": "@caroline:test", "field": "displayname"} + ) ) self.assertEquals({"displayname": "Caroline"}, response) @@ -166,8 +172,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) - - avatar_url = yield self.handler.get_avatar_url(self.frank) + avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) self.assertEquals("http://my.server/me.png", avatar_url) -- cgit 1.5.1 From 6b3ac3b8cddda9911f42a08a0dcefc4a3386ff51 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Jul 2020 07:09:25 -0400 Subject: Convert device handler to async/await (#7871) --- changelog.d/7871.misc | 1 + synapse/handlers/device.py | 241 +++++++++++++++++----------------------- synapse/util/distributor.py | 28 ++++- 
tests/handlers/test_device.py | 13 +-- tests/handlers/test_e2e_keys.py | 10 +- tests/test_federation.py | 35 +++--- 6 files changed, 162 insertions(+), 166 deletions(-) create mode 100644 changelog.d/7871.misc (limited to 'tests') diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc new file mode 100644 index 0000000000..4d398a9f3a --- /dev/null +++ b/changelog.d/7871.misc @@ -0,0 +1 @@ +Convert device handler to async/await. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 31346b56c3..f947aa1627 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional - -from twisted.internet import defer +from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes @@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: """ Retrieve the given user's devices Args: - user_id (str): + user_id: The user ID to query for devices. Returns: - defer.Deferred: list[dict[str, X]]: info on each device + info on each device """ set_tag("user_id", user_id) - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - @defer.inlineCallbacks - def get_device(self, user_id, device_id): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """ Retrieve the given device Args: - user_id (str): - device_id (str): + user_id: The user to get the device from + device_id: The device to fetch. Returns: - defer.Deferred: dict[str, X]: info on the device + info on the device Raises: errors.NotFoundError: if the device was not found """ try: - device = yield self.store.get_device(user_id, device_id) + device = await self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) set_tag("device", device) @@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler): return device - @measure_func("device.get_user_ids_changed") @trace - @defer.inlineCallbacks - def get_user_ids_changed(self, user_id, from_token): + @measure_func("device.get_user_ids_changed") + async def get_user_ids_changed(self, user_id, from_token): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. 
@@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = yield self.store.get_room_events_max_id() + now_room_key = await self.store.get_room_events_max_id() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # First we check if any devices have changed for users that we share # rooms with. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler): # Always tell the user about their own devices tracked_users.add(user_id) - changed = yield self.store.get_users_whose_devices_changed( + changed = await self.store.get_users_whose_devices_changed( from_token.device_list_key, tracked_users ) # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) - member_events = yield self.store.get_membership_changes_for_user( + member_events = await self.store.get_membership_changes_for_user( user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) @@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler): possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = yield self.store.get_current_state_ids(room_id) + current_state_ids = await self.store.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. @@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler): # Fetch the current state at the time. try: - event_ids = yield self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. 
@@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler): return result - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") - self_signing_key = yield self.store.get_e2e_cross_signing_key( + async def on_federation_query_user_devices(self, user_id): + stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" ) @@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - @defer.inlineCallbacks - def check_device_registered( + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): """ @@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler): str: device id (generated if none was supplied) """ if device_id is not None: - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id # if the device id is not specified, we'll autogen one, but loop a few @@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler): attempts = 0 while attempts < 5: device_id = stringutils.random_string(10).upper() - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @trace - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """ Delete the given device Args: - user_id (str): - device_id (str): - - Returns: - defer.Deferred: + user_id: The user to delete the device from. + device_id: The device to delete. 
""" try: - yield self.store.delete_device(user_id, device_id) + await self.store.delete_device(user_id, device_id) except errors.StoreError as e: if e.code == 404: # no match @@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) @trace - @defer.inlineCallbacks - def delete_all_devices_for_user(self, user_id, except_device_id=None): + async def delete_all_devices_for_user( + self, user_id: str, except_device_id: Optional[str] = None + ) -> None: """Delete all of the user's devices Args: - user_id (str): - except_device_id (str|None): optional device id which should not - be deleted - - Returns: - defer.Deferred: + user_id: The user to remove all devices from + except_device_id: optional device id which should not be deleted """ - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) device_ids = list(device_map) if except_device_id is not None: device_ids = [d for d in device_ids if d != except_device_id] - yield self.delete_devices(user_id, device_ids) + await self.delete_devices(user_id, device_ids) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """ Delete several devices Args: - user_id (str): - device_ids (List[str]): The list of device IDs to delete - - Returns: - defer.Deferred: + user_id: The user to delete devices from. + device_ids: The list of device IDs to delete """ try: - yield self.store.delete_devices(user_id, device_ids) + await self.store.delete_devices(user_id, device_ids) except errors.StoreError as e: if e.code == 404: # no match @@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( + await self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_ids) + await self.notify_device_update(user_id, device_ids) - @defer.inlineCallbacks - def update_device(self, user_id, device_id, content): + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """ Update the given device Args: - user_id (str): - device_id (str): - content (dict): body of update request - - Returns: - defer.Deferred: + user_id: The user to update devices of. + device_id: The device to update. + content: body of update request """ # Reject a new displayname which is too long. 
@@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler): ) try: - yield self.store.update_device( + await self.store.update_device( user_id, device_id, new_display_name=new_display_name ) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() @@ -443,12 +417,11 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") - @defer.inlineCallbacks - def notify_device_update(self, user_id, device_ids): + async def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -459,7 +432,7 @@ class DeviceHandler(DeviceWorkerHandler): set_tag("target_hosts", hosts) - position = yield self.store.add_device_change_to_streams( + position = await self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) @@ -468,11 +441,11 @@ class DeviceHandler(DeviceWorkerHandler): "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - yield self.notifier.on_new_event( + self.notifier.on_new_event( "device_list_key", position, users=[user_id], rooms=room_ids ) @@ -484,29 +457,29 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender.send_device_messages(host) log_kv({"message": "sent device update to host", "host": host}) - @defer.inlineCallbacks - def notify_user_signature_update(self, from_user_id, user_ids): + async def notify_user_signature_update( + self, from_user_id: str, user_ids: List[str] + ) -> None: """Notify a user that they have made new signatures of other users. Args: - from_user_id (str): the user who made the signature - user_ids (list[str]): the users IDs that have new signatures + from_user_id: the user who made the signature + user_ids: the users IDs that have new signatures """ - position = yield self.store.add_user_signature_change_to_streams( + position = await self.store.add_user_signature_change_to_streams( from_user_id, user_ids ) self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - @defer.inlineCallbacks - def user_left_room(self, user, room_id): + async def user_left_room(self, user, room_id): user_id = user.to_string() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We no longer share rooms with this user, so we'll no longer # receive device updates. Mark this in DB. - yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) def _update_device_from_client_ips(device, client_ips): @@ -549,8 +522,7 @@ class DeviceListUpdater(object): ) @trace - @defer.inlineCallbacks - def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update(self, origin, edu_content): """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. 
""" @@ -583,7 +555,7 @@ class DeviceListUpdater(object): ) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -608,14 +580,13 @@ class DeviceListUpdater(object): (device_id, stream_id, prev_ids, edu_content) ) - yield self._handle_device_updates(user_id) + await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - @defer.inlineCallbacks - def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id): "Actually handle pending updates." - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -632,7 +603,7 @@ class DeviceListUpdater(object): # Given a list of updates we check if we need to resync. This # happens if we've missed updates. - resync = yield self._need_to_do_resync(user_id, pending_updates) + resync = await self._need_to_do_resync(user_id, pending_updates) if logger.isEnabledFor(logging.INFO): logger.info( @@ -643,16 +614,16 @@ class DeviceListUpdater(object): ) if resync: - yield self.user_device_resync(user_id) + await self.user_device_resync(user_id) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: - yield self.store.update_remote_device_list_cache_entry( + await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user_id, [device_id for device_id, _, _, _ in pending_updates] ) @@ -660,14 +631,13 @@ class DeviceListUpdater(object): stream_id for _, stream_id, _, _ in pending_updates ) - @defer.inlineCallbacks - def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync(self, user_id, updates): """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) + extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) logger.debug("Current extremity for %r: %r", user_id, extremity) @@ -692,8 +662,7 @@ class DeviceListUpdater(object): return False @trace - @defer.inlineCallbacks - def _maybe_retry_device_resync(self): + async def _maybe_retry_device_resync(self): """Retry to resync device lists that are out of sync, except if another retry is in progress. """ @@ -705,12 +674,12 @@ class DeviceListUpdater(object): # we don't send too many requests. self._resync_retry_in_progress = True # Get all of the users that need resyncing. - need_resync = yield self.store.get_user_ids_requiring_device_list_resync() + need_resync = await self.store.get_user_ids_requiring_device_list_resync() # Iterate over the set of user IDs. for user_id in need_resync: try: # Try to resync the current user's devices list. 
- result = yield self.user_device_resync( + result = await self.user_device_resync( user_id=user_id, mark_failed_as_stale=False, ) @@ -734,16 +703,17 @@ class DeviceListUpdater(object): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False - @defer.inlineCallbacks - def user_device_resync(self, user_id, mark_failed_as_stale=True): + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[dict]: """Fetches all devices for a user and updates the device cache with them. Args: - user_id (str): The user's id whose device_list will be updated. - mark_failed_as_stale (bool): Whether to mark the user's device list as stale + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale if the attempt to resync failed. Returns: - Deferred[dict]: a dict with device info as under the "devices" in the result of this + A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid """ @@ -752,12 +722,12 @@ class DeviceListUpdater(object): # Fetch all devices for the user. origin = get_domain_from_id(user_id) try: - result = yield self.federation.query_user_devices(origin, user_id) + result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return except (RequestSendFailed, HttpResponseException) as e: @@ -768,7 +738,7 @@ class DeviceListUpdater(object): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list @@ -792,7 +762,7 @@ class DeviceListUpdater(object): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return log_kv({"result": result}) @@ -833,25 +803,24 @@ class DeviceListUpdater(object): stream_id, ) - yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) + await self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. - cross_signing_device_ids = yield self.process_cross_signing_key_update( + cross_signing_device_ids = await self.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + cross_signing_device_ids - yield self.device_handler.notify_device_update(user_id, device_ids) + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. 
self._seen_updates[user_id] = {stream_id} - defer.returnValue(result) + return result - @defer.inlineCallbacks - def process_cross_signing_key_update( + async def process_cross_signing_key_update( self, user_id: str, master_key: Optional[Dict[str, Any]], @@ -872,14 +841,14 @@ class DeviceListUpdater(object): device_ids = [] if master_key: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) if self_signing_key: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index da20523b70..22a857a306 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import logging from twisted.internet import defer +from twisted.internet.defer import Deferred, fail, succeed +from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -79,6 +81,28 @@ class Distributor(object): run_as_background_process(name, self.signals[name].fire, *args, **kwargs) +def maybeAwaitableDeferred(f, *args, **kw): + """ + Invoke a function that may or may not return a Deferred or an Awaitable. + + This is a modified version of twisted.internet.defer.maybeDeferred. + """ + try: + result = f(*args, **kw) + except Exception: + return fail(failure.Failure(captureVars=Deferred.debug)) + + if isinstance(result, Deferred): + return result + # Handle the additional case of an awaitable being returned. + elif inspect.isawaitable(result): + return defer.ensureDeferred(result) + elif isinstance(result, failure.Failure): + return fail(result) + else: + return succeed(result) + + class Signal(object): """A Signal is a dispatch point that stores a list of callables as observers of it. 
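As a brief usage sketch (not part of the patch; the observer functions below are invented for illustration), the new maybeAwaitableDeferred helper accepts either style of observer and always hands back a Deferred, which is what lets Signal.fire keep chaining addErrback in the hunk that follows:

    from twisted.internet import defer
    from synapse.util.distributor import maybeAwaitableDeferred

    def deferred_observer(user_id):
        # Old-style observer: returns a Deferred directly.
        return defer.succeed(user_id)

    async def coroutine_observer(user_id):
        # New-style observer: calling it returns a coroutine.
        return user_id

    # Both calls come back as Deferreds, so .addErrback() and friends keep working.
    d1 = maybeAwaitableDeferred(deferred_observer, "@user:test")
    d2 = maybeAwaitableDeferred(coroutine_observer, "@user:test")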
@@ -122,7 +146,7 @@ class Signal(object): ), ) - return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 62b47f6574..6aa322bf3a 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - res = self.handler.get_device(user1, "abc") - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError ) # we'd like to check the access token was invalidated, but that's a @@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): def test_update_unknown_device(self): update = {"display_name": "new_display"} - res = self.handler.update_device("user_id", "unknown_device_id", update) - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.update_device("user_id", "unknown_device_id", update), + synapse.api.errors.NotFoundError, ) def _record_users(self): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index cdd093ffa8..210ddcbb88 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = None try: - yield self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", + yield defer.ensureDeferred( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) ) except errors.SynapseError as e: res = e.code diff --git a/tests/test_federation.py b/tests/test_federation.py index 89dcc58b99..87a16d7d7a 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) + store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. 
federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { + return_value=succeed( + { "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - }, - } + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) ) # Resync the device list. -- cgit 1.5.1 From cc9bb3dc3f299d451ab523dea192e48c32e87c68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Jul 2020 12:29:15 -0400 Subject: Convert the message handler to async/await. (#7884) --- changelog.d/7884.misc | 1 + synapse/handlers/message.py | 288 ++++++++++++++------------- tests/events/test_snapshot.py | 36 ++-- tests/replication/tcp/streams/test_events.py | 76 ++++--- tests/storage/test_roommember.py | 56 +++--- tests/storage/test_state.py | 4 +- tests/test_utils/event_injection.py | 28 +-- tests/test_visibility.py | 14 +- tests/unittest.py | 4 +- tests/utils.py | 4 +- 10 files changed, 273 insertions(+), 238 deletions(-) create mode 100644 changelog.d/7884.misc (limited to 'tests') diff --git a/changelog.d/7884.misc b/changelog.d/7884.misc new file mode 100644 index 0000000000..36c7d4de67 --- /dev/null +++ b/changelog.d/7884.misc @@ -0,0 +1 @@ +Convert the message handler to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c47764a4ce..172a7214b2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,12 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from canonicaljson import encode_canonical_json, json -from twisted.internet import defer -from twisted.internet.defer import succeed from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -41,13 +39,22 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events import EventBase +from synapse.events.builder import EventBuilder +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Collection, RoomAlias, UserID, create_requester +from synapse.types import ( + Collection, + Requester, + RoomAlias, + StreamToken, + UserID, + create_requester, +) from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -84,14 +91,22 @@ class MessageHandler(object): "_schedule_next_expiry", self._schedule_next_expiry ) - @defer.inlineCallbacks - def get_room_data( - self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False - ): + async def get_room_data( + self, + user_id: str = None, + room_id: str = None, + event_type: Optional[str] = None, + state_key: str = "", + is_guest: bool = False, + ) -> dict: """ Get data from a room. Args: - event : The room path event + user_id + room_id + event_type + state_key + is_guest Returns: The path data content. Raises: @@ -100,30 +115,29 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - data = yield self.state.get_current_state(room_id, event_type, state_key) + data = await self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) return data - @defer.inlineCallbacks - def get_state_events( + async def get_state_events( self, - user_id, - room_id, - state_filter=StateFilter.all(), - at_token=None, - is_guest=False, - ): + user_id: str, + room_id: str, + state_filter: StateFilter = StateFilter.all(), + at_token: Optional[StreamToken] = None, + is_guest: bool = False, + ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. If an explicit @@ -131,15 +145,14 @@ class MessageHandler(object): visible. Args: - user_id(str): The user requesting state events. - room_id(str): The room ID to get all state events from. - state_filter (StateFilter): The state filter used to fetch state - from the database. 
- at_token(StreamToken|None): the stream token of the at which we are requesting + user_id: The user requesting state events. + room_id: The room ID to get all state events from. + state_filter: The state filter used to fetch state from the database. + at_token: the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest(bool): whether this user is a guest + is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -153,20 +166,20 @@ class MessageHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=at_token.room_key, limit=1 ) if not last_events: raise NotFoundError("Can't find event for token %s" % (at_token,)) - visible_events = yield filter_events_for_client( + visible_events = await filter_events_for_client( self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] if visible_events: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -180,23 +193,23 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - state_ids = yield self.store.get_filtered_current_state_ids( + state_ids = await self.store.get_filtered_current_state_ids( room_id, state_filter=state_filter ) - room_state = yield self.store.get_events(state_ids.values()) + room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( room_state.values(), now, # We don't bother bundling aggregations in when asked for state @@ -205,15 +218,14 @@ class MessageHandler(object): ) return events - @defer.inlineCallbacks - def get_joined_members(self, requester, room_id): + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. If the user has left the room return the state events from when they left. Args: - requester(Requester): The user requesting state events. - room_id(str): The room ID to get all state events from. + requester: The user requesting state events. + room_id: The room ID to get all state events from. Returns: A dict of user_id to profile info """ @@ -221,7 +233,7 @@ class MessageHandler(object): if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. 
- membership, _ = yield self.auth.check_user_in_room_or_world_readable( + membership, _ = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership != Membership.JOIN: @@ -229,7 +241,7 @@ class MessageHandler(object): "Getting joined members after leaving is not implemented" ) - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there @@ -250,7 +262,7 @@ class MessageHandler(object): for user_id, profile in users_with_profile.items() } - def maybe_schedule_expiry(self, event): + def maybe_schedule_expiry(self, event: EventBase): """Schedule the expiry of an event if there's not already one scheduled, or if the one running is for an event that will expire after the provided timestamp. @@ -259,7 +271,7 @@ class MessageHandler(object): the master process, and therefore needs to be run on there. Args: - event (EventBase): The event to schedule the expiry of. + event: The event to schedule the expiry of. """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) @@ -270,8 +282,7 @@ class MessageHandler(object): # a task scheduled for a timestamp that's sooner than the provided one. self._schedule_expiry_for_event(event.event_id, expiry_ts) - @defer.inlineCallbacks - def _schedule_next_expiry(self): + async def _schedule_next_expiry(self): """Retrieve the ID and the expiry timestamp of the next event to be expired, and schedule an expiry task for it. @@ -279,18 +290,18 @@ class MessageHandler(object): future call to save_expiry_ts can schedule a new expiry task. """ # Try to get the expiry timestamp of the next event to expire. - res = yield self.store.get_next_event_to_expire() + res = await self.store.get_next_event_to_expire() if res: event_id, expiry_ts = res self._schedule_expiry_for_event(event_id, expiry_ts) - def _schedule_expiry_for_event(self, event_id, expiry_ts): + def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int): """Schedule an expiry task for the provided event if there's not already one scheduled at a timestamp that's sooner than the provided one. Args: - event_id (str): The ID of the event to expire. - expiry_ts (int): The timestamp at which to expire the event. + event_id: The ID of the event to expire. + expiry_ts: The timestamp at which to expire the event. """ if self._scheduled_expiry: # If the provided timestamp refers to a time before the scheduled time of the @@ -320,8 +331,7 @@ class MessageHandler(object): event_id, ) - @defer.inlineCallbacks - def _expire_event(self, event_id): + async def _expire_event(self, event_id: str): """Retrieve and expire an event that needs to be expired from the database. If the event doesn't exist in the database, log it and delete the expiry date @@ -336,12 +346,12 @@ class MessageHandler(object): try: # Expire the event if we know about it. This function also deletes the expiry # date from the database in the same database transaction. - yield self.store.expire_event(event_id) + await self.store.expire_event(event_id) except Exception as e: logger.error("Could not expire event %s: %r", event_id, e) # Schedule the expiry of the next event to expire. 
- yield self._schedule_next_expiry() + await self._schedule_next_expiry() # The duration (in ms) after which rooms should be removed @@ -423,16 +433,15 @@ class EventCreationHandler(object): self._dummy_events_threshold = hs.config.dummy_events_threshold - @defer.inlineCallbacks - def create_event( + async def create_event( self, - requester, - event_dict, - token_id=None, - txn_id=None, + requester: Requester, + event_dict: dict, + token_id: Optional[str] = None, + txn_id: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, - require_consent=True, - ): + require_consent: bool = True, + ) -> Tuple[EventBase, EventContext]: """ Given a dict from a client, create a new event. @@ -443,31 +452,29 @@ class EventCreationHandler(object): Args: requester - event_dict (dict): An entire event - token_id (str) - txn_id (str) - + event_dict: An entire event + token_id + txn_id prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. - - require_consent (bool): Whether to check if the requester has - consented to privacy policy. + require_consent: Whether to check if the requester has + consented to the privacy policy. Raises: ResourceLimitError if server is blocked to some resource being exceeded Returns: - Tuple of created event (FrozenEvent), Context + Tuple of created event, Context """ - yield self.auth.check_auth_blocking(requester.user.to_string()) + await self.auth.check_auth_blocking(requester.user.to_string()) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version_id( + room_version = await self.store.get_room_version_id( event_dict["room_id"] ) except NotFoundError: @@ -488,15 +495,11 @@ class EventCreationHandler(object): try: if "displayname" not in content: - displayname = yield defer.ensureDeferred( - profile.get_displayname(target) - ) + displayname = await profile.get_displayname(target) if displayname is not None: content["displayname"] = displayname if "avatar_url" not in content: - avatar_url = yield defer.ensureDeferred( - profile.get_avatar_url(target) - ) + avatar_url = await profile.get_avatar_url(target) if avatar_url is not None: content["avatar_url"] = avatar_url except Exception as e: @@ -504,9 +507,9 @@ class EventCreationHandler(object): "Failed to get profile information for %r: %s", target, e ) - is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) + is_exempt = await self._is_exempt_from_privacy_policy(builder, requester) if require_consent and not is_exempt: - yield self.assert_accepted_privacy_policy(requester) + await self.assert_accepted_privacy_policy(requester) if token_id is not None: builder.internal_metadata.token_id = token_id @@ -514,7 +517,7 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - event, context = yield self.create_new_client_event( + event, context = await self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -530,10 +533,10 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). 
- prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( - yield self.store.get_event(prev_event_id, allow_none=True) + await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id else None ) @@ -556,37 +559,36 @@ class EventCreationHandler(object): return (event, context) - def _is_exempt_from_privacy_policy(self, builder, requester): + async def _is_exempt_from_privacy_policy( + self, builder: EventBuilder, requester: Requester + ) -> bool: """"Determine if an event to be sent is exempt from having to consent to the privacy policy Args: - builder (synapse.events.builder.EventBuilder): event being created - requester (Requster): user requesting this event + builder: event being created + requester: user requesting this event Returns: - Deferred[bool]: true if the event can be sent without the user - consenting + true if the event can be sent without the user consenting """ # the only thing the user can do is join the server notices room. if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return self._is_server_notices_room(builder.room_id) + return await self._is_server_notices_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return builder.state_key == requester.user.to_string() - return succeed(False) + return False - @defer.inlineCallbacks - def _is_server_notices_room(self, room_id): + async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.server_notices_mxid is None: return False - user_ids = yield self.store.get_users_in_room(room_id) + user_ids = await self.store.get_users_in_room(room_id) return self.config.server_notices_mxid in user_ids - @defer.inlineCallbacks - def assert_accepted_privacy_policy(self, requester): + async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy Called when the given user is about to do something that requires @@ -595,12 +597,10 @@ class EventCreationHandler(object): raised. Args: - requester (synapse.types.Requester): - The user making the request + requester: The user making the request Returns: - Deferred[None]: returns normally if the user has consented or is - exempt + Returns normally if the user has consented or is exempt Raises: ConsentNotGivenError: if the user has not given consent yet @@ -621,7 +621,7 @@ class EventCreationHandler(object): ): return - u = yield self.store.get_user_by_id(user_id) + u = await self.store.get_user_by_id(user_id) assert u is not None if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users are not required to consent @@ -639,16 +639,20 @@ class EventCreationHandler(object): raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) async def send_nonmember_event( - self, requester, event, context, ratelimit=True + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, ) -> int: """ Persists and notifies local clients and federation of an event. Args: - event (FrozenEvent) the event to send. - context (Context) the context of the event. - ratelimit (bool): Whether to rate limit this send. - is_guest (bool): Whether the sender is a guest. + requester + event the event to send. + context: the context of the event. + ratelimit: Whether to rate limit this send. 
Return: The stream_id of the persisted event. @@ -676,19 +680,20 @@ class EventCreationHandler(object): requester=requester, event=event, context=context, ratelimit=ratelimit ) - @defer.inlineCallbacks - def deduplicate_state_event(self, event, context): + async def deduplicate_state_event( + self, event: EventBase, context: EventContext + ) -> None: """ Checks whether event is in the latest resolved state in context. If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -700,7 +705,11 @@ class EventCreationHandler(object): return async def create_and_send_nonmember_event( - self, requester, event_dict, ratelimit=True, txn_id=None + self, + requester: Requester, + event_dict: EventBase, + ratelimit: bool = True, + txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -730,17 +739,17 @@ class EventCreationHandler(object): return event, stream_id @measure_func("create_new_client_event") - @defer.inlineCallbacks - def create_new_client_event( - self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None - ): + async def create_new_client_event( + self, + builder: EventBuilder, + requester: Optional[Requester] = None, + prev_event_ids: Optional[Collection[str]] = None, + ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client Args: - builder (EventBuilder): - - requester (synapse.types.Requester|None): - + builder: + requester: prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -748,7 +757,7 @@ class EventCreationHandler(object): If None, they will be requested from the database. Returns: - Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + Tuple of created event, context """ if prev_event_ids is not None: @@ -757,10 +766,10 @@ class EventCreationHandler(object): % (len(prev_event_ids),) ) else: - prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) + prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - event = yield builder.build(prev_event_ids=prev_event_ids) - context = yield self.state.compute_event_context(event) + event = await builder.build(prev_event_ids=prev_event_ids) + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -774,7 +783,7 @@ class EventCreationHandler(object): relates_to = relation["event_id"] aggregation_key = relation["key"] - already_exists = yield self.store.has_user_annotated_event( + already_exists = await self.store.has_user_annotated_event( relates_to, event.type, aggregation_key, event.sender ) if already_exists: @@ -786,7 +795,12 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. 
@@ -795,11 +809,11 @@ class EventCreationHandler(object): processing. Args: - requester (Requester) - event (FrozenEvent) - context (EventContext) - ratelimit (bool) - extra_users (list(UserID)): Any extra users to notify about event + requester + event + context + ratelimit + extra_users: Any extra users to notify about event Return: The stream_id of the persisted event. @@ -878,10 +892,9 @@ class EventCreationHandler(object): self.store.remove_push_actions_from_staging, event.event_id ) - @defer.inlineCallbacks - def _validate_canonical_alias( - self, directory_handler, room_alias_str, expected_room_id - ): + async def _validate_canonical_alias( + self, directory_handler, room_alias_str: str, expected_room_id: str + ) -> None: """ Ensure that the given room alias points to the expected room ID. @@ -892,9 +905,7 @@ class EventCreationHandler(object): """ room_alias = RoomAlias.from_string(room_alias_str) try: - mapping = yield defer.ensureDeferred( - directory_handler.get_association(room_alias) - ) + mapping = await directory_handler.get_association(room_alias) except SynapseError as e: # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. if e.errcode == Codes.NOT_FOUND: @@ -913,7 +924,12 @@ class EventCreationHandler(object): ) async def persist_and_notify_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1106,7 +1122,7 @@ class EventCreationHandler(object): return event_stream_id - async def _bump_active_time(self, user): + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() await presence.bump_presence_active_time(user) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 640f5f3bce..3a80626224 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase): serialize/deserialize. """ - event, context = create_event( - self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + event, context = self.get_success( + create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) ) self._check_serialize_deserialize(event, context) @@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.test", - sender=self.user_id, - state_key="", + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) ) self._check_serialize_deserialize(event, context) @@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. 
""" - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.room.member", - sender=self.user_id, - state_key=self.user_id, - content={"membership": "leave"}, + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) ) self._check_serialize_deserialize(event, context) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 097e1653b4..c9998e88e6 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase): OTHER_USER = "@other_user:localhost" # have the user join - inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase): # roll back all the state by de-modding the user prev_events = fork_point pls["users"][OTHER_USER] = 0 - pl_event = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + pl_event = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) # one more bit of state that doesn't get rolled back @@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # have the users join for u in user_ids: - inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_events = [] for u in user_ids: pls["users"][u] = 0 - e = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + e = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) prev_events = [e.event_id] pl_events.append(e) @@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase): body = "event %i" % (self.event_count,) self.event_count += 1 - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_event", - content={"body": body}, - **kwargs + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) ) def _inject_state_event( @@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase): if body is None: body = "state event %s" % (state_key,) - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_state_event", - state_key=state_key, - content={"body": body}, + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5dd46005e6..f282921538 100644 
--- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def test_get_joined_users_from_context(self): room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN + bob_event = self.get_success( + event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) ) # first, create a regular event - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) ) users = self.get_success( @@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Regression test for #7376: create a state event whose key matches bob's # user_id, but which is *not* a membership event, and persist that; then check # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, + non_member_event = self.get_success( + event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) ) - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) ) users = self.get_success( self.store.get_joined_users_from_context(event, context) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b88308ff4..a0e133cd4a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 43297b530c..8522c6fc09 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -22,14 +22,12 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Collection -from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. 
""" -def inject_member_event( +async def inject_member_event( hs: synapse.server.HomeServer, room_id: str, sender: str, @@ -46,7 +44,7 @@ def inject_member_event( if extra_content: content.update(extra_content) - return inject_event( + return await inject_event( hs, room_id=room_id, type=EventTypes.Member, @@ -57,7 +55,7 @@ def inject_member_event( ) -def inject_event( +async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, @@ -72,37 +70,27 @@ def inject_event( prev_event_ids: prev_events for the event. If not specified, will be looked up kwargs: fields for the event to be created """ - test_reactor = hs.get_reactor() - - event, context = create_event(hs, room_version, prev_event_ids, **kwargs) + event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - d = hs.get_storage().persistence.persist_event(event, context) - test_reactor.advance(0) - get_awaitable_result(d) + await hs.get_storage().persistence.persist_event(event, context) return event -def create_event( +async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: - test_reactor = hs.get_reactor() - if room_version is None: - d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) - test_reactor.advance(0) - room_version = get_awaitable_result(d) + room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - d = hs.get_event_creation_handler().create_new_client_event( + event, context = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) - test_reactor.advance(0) - event, context = get_awaitable_result(d) return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f7381b2885..b371efc0df 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # # before we do that, we persist some other events to act as state. 
- self.inject_visibility("@admin:hs", "joined") + yield self.inject_visibility("@admin:hs", "joined") for i in range(0, 10): yield self.inject_room_member("@resident%i:hs" % i) @@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) return event @@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) @@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) diff --git a/tests/unittest.py b/tests/unittest.py index 3175a3fa02..68d2586efd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase): user: MXID of the user to inject the membership for. membership: The membership type. """ - event_injection.inject_member_event(self.hs, room, user, membership) + self.get_success( + event_injection.inject_member_event(self.hs, room, user, membership) + ) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 4d17355a5c..ac643679aa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield event_creation_handler.create_new_client_event(builder) + event, context = yield defer.ensureDeferred( + event_creation_handler.create_new_client_event(builder) + ) yield persistence_store.persist_event(event, context) -- cgit 1.5.1 From 68cd935826b912aea365de8b6aa589e35360cc85 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Jul 2020 07:05:57 -0400 Subject: Convert the federation agent and related code to async/await. (#7874) --- changelog.d/7874.misc | 1 + synapse/http/federation/matrix_federation_agent.py | 16 +++---- synapse/http/federation/srv_resolver.py | 10 ++--- .../federation/test_matrix_federation_agent.py | 51 +++++++++++++--------- tests/http/federation/test_srv_resolver.py | 26 +++++------ 5 files changed, 51 insertions(+), 53 deletions(-) create mode 100644 changelog.d/7874.misc (limited to 'tests') diff --git a/changelog.d/7874.misc b/changelog.d/7874.misc new file mode 100644 index 0000000000..f75c8d1843 --- /dev/null +++ b/changelog.d/7874.misc @@ -0,0 +1 @@ +Convert the federation agent and related code to async/await. 
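The conversions in this and the following commits all have the same shape: a Twisted @defer.inlineCallbacks generator that yields Deferreds becomes a native coroutine that awaits them, and callers that still work in terms of Deferreds wrap the coroutine with defer.ensureDeferred. A minimal, self-contained sketch of that pattern, using hypothetical names rather than the Synapse functions touched below:

from twisted.internet import defer, task


@defer.inlineCallbacks
def resolve_old(lookup):
    # Old style: a generator driven by inlineCallbacks; Deferreds are yielded.
    servers = yield lookup()
    return [s.upper() for s in servers]


async def resolve_new(lookup):
    # New style: a native coroutine; Deferreds (and other awaitables) are awaited.
    servers = await lookup()
    return [s.upper() for s in servers]


def main(reactor):
    def lookup():
        # Stand-in for an SRV/DNS lookup that returns a Deferred.
        return defer.succeed(["host1", "host2"])

    d_old = resolve_old(lookup)  # inlineCallbacks already hands back a Deferred
    d_new = defer.ensureDeferred(resolve_new(lookup))  # coroutines need wrapping
    both = defer.gatherResults([d_old, d_new])
    both.addCallback(print)
    return both


if __name__ == "__main__":
    task.react(main)

The test-side changes above are the same bridge seen from the other end: HomeserverTestCase.get_success and defer.ensureDeferred drive the newly async helpers from test code that is still synchronous or inlineCallbacks-based.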
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index c5fc746f2f..0c02648015 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -15,6 +15,7 @@ import logging import urllib +from typing import List from netaddr import AddrFormatError, IPAddress from zope.interface import implementer @@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object): return run_in_background(self._do_connect, protocol_factory) - @defer.inlineCallbacks - def _do_connect(self, protocol_factory): + async def _do_connect(self, protocol_factory): first_exception = None - server_list = yield self._resolve_server() + server_list = await self._resolve_server() for server in server_list: host = server.host @@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object): endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) - result = yield make_deferred_yieldable( + result = await make_deferred_yieldable( endpoint.connect(protocol_factory) ) @@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object): # to try and if that doesn't work then we'll have an exception. raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - @defer.inlineCallbacks - def _resolve_server(self): + async def _resolve_server(self) -> List[Server]: """Resolves the server name to a list of hosts and ports to attempt to connect to. - - Returns: - Deferred[list[Server]] """ if self._parsed_uri.scheme != b"matrix": @@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object): if port or _is_ip_literal(host): return [Server(host, port or 8448)] - server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) + server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." 
+ host) if server_list: return server_list diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 021b233a7d..2ede90a9b1 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -17,10 +17,10 @@ import logging import random import time +from typing import List import attr -from twisted.internet import defer from twisted.internet.error import ConnectError from twisted.names import client, dns from twisted.names.error import DNSNameError, DomainError @@ -113,16 +113,14 @@ class SrvResolver(object): self._cache = cache self._get_time = get_time - @defer.inlineCallbacks - def resolve_service(self, service_name): + async def resolve_service(self, service_name: bytes) -> List[Server]: """Look up a SRV record Args: service_name (bytes): record to look up Returns: - Deferred[list[Server]]: - a list of the SRV records, or an empty list if none found + a list of the SRV records, or an empty list if none found """ now = int(self._get_time()) @@ -136,7 +134,7 @@ class SrvResolver(object): return _sort_server_list(servers) try: - answers, _, _ = yield make_deferred_yieldable( + answers, _, _ = await make_deferred_yieldable( self._dns_client.lookupService(service_name) ) except DNSNameError: diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 954e059e76..69945a8f98 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -67,6 +67,14 @@ def get_connection_factory(): return test_server_connection_factory +# Once Async Mocks or lambdas are supported this can go away. +def generate_resolve_service(result): + async def resolve_service(_): + return result + + return resolve_service + + class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv1/foo/bar") @@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has no port, no SRV, and no well-known """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known delegates elsewhere """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" 
self.reactor.lookups["target-server"] = "1::f" @@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self._handle_well_known_connection( client_factory, @@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"xn--trget-3qa.com", port=8443) # târget.com - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com + ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") @@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. 
""" - - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + ) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index babc201643..fee2985d35 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase): with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - - self.assertNoResult(resolve_d) - - # should have reset to the sentinel context - self.assertIs(current_context(), SENTINEL_CONTEXT) - - result = yield resolve_d + result = yield defer.ensureDeferred(resolve_d) # should have restored our context self.assertIs(current_context(), ctx) @@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertFalse(dns_client_mock.lookupService.called) @@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): - yield resolver.resolve_service(service_name) + yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks def test_name_error(self): @@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) @@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in failureResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) # returning a single "." 
should make the lookup fail with a ConenctError lookup_deferred.callback( @@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in successResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) lookup_deferred.callback( ( -- cgit 1.5.1 From 1ec688bf21cd1368a2bb86c2de977daf148eecc3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 24 Jul 2020 09:55:47 +0100 Subject: Downgrade warning on client disconnect to INFO (#7928) Clients disconnecting before we finish processing the request happens from time to time. We don't need to yell about it --- changelog.d/7928.misc | 1 + synapse/http/site.py | 4 +--- tests/test_server.py | 59 +-------------------------------------------------- 3 files changed, 3 insertions(+), 61 deletions(-) create mode 100644 changelog.d/7928.misc (limited to 'tests') diff --git a/changelog.d/7928.misc b/changelog.d/7928.misc new file mode 100644 index 0000000000..5f3aa5de0a --- /dev/null +++ b/changelog.d/7928.misc @@ -0,0 +1 @@ +When a client disconnects, don't log it as 'Error processing request'. diff --git a/synapse/http/site.py b/synapse/http/site.py index cbc37eac6e..6f3b2258cc 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -215,9 +215,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warning( - "Error processing request %r: %s %s", self, reason.type, reason.value - ) + logger.info("Connection from client lost before response was sent") if not self._is_processing: self._finished_processing() diff --git a/tests/test_server.py b/tests/test_server.py index 030f58cbdc..42cada8964 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -12,26 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import re -from io import StringIO from twisted.internet.defer import Deferred -from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource -from synapse.http.site import SynapseSite, logger +from synapse.http.site import SynapseSite from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock from tests import unittest from tests.server import ( - FakeTransport, ThreadedMemoryReactorClock, make_request, render, @@ -318,54 +312,3 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(location_headers, [b"/no/over/there"]) cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) - - -class SiteTestCase(unittest.HomeserverTestCase): - def test_lose_connection(self): - """ - We log the URI correctly redacted when we lose the connection. - """ - - class HangingResource(Resource): - """ - A Resource that strategically hangs, as if it were processing an - answer. 
- """ - - def render(self, request): - return NOT_DONE_YET - - # Set up a logging handler that we can inspect afterwards - output = StringIO() - handler = logging.StreamHandler(output) - logger.addHandler(handler) - old_level = logger.level - logger.setLevel(10) - self.addCleanup(logger.setLevel, old_level) - self.addCleanup(logger.removeHandler, handler) - - # Make a resource and a Site, the resource will hang and allow us to - # time out the request while it's 'processing' - base_resource = Resource() - base_resource.putChild(b"", HangingResource()) - site = SynapseSite( - "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0" - ) - - server = site.buildProtocol(None) - client = AccumulatingProtocol() - client.makeConnection(FakeTransport(server, self.reactor)) - server.makeConnection(FakeTransport(client, self.reactor)) - - # Send a request with an access token that will get redacted - server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n") - self.pump() - - # Lose the connection - e = Failure(Exception("Failed123")) - server.connectionLost(e) - handler.flush() - - # Our access token is redacted and the failure reason is logged. - self.assertIn("/?access_token=", output.getvalue()) - self.assertIn("Failed123", output.getvalue()) -- cgit 1.5.1 From 6a080ea184844f6ee9412a8d6170eb7ff2e5dd56 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 07:08:07 -0400 Subject: Return an empty body for OPTIONS requests. (#7886) --- changelog.d/7886.misc | 1 + synapse/http/server.py | 24 +++++------------------- tests/test_server.py | 12 ++++++------ 3 files changed, 12 insertions(+), 25 deletions(-) create mode 100644 changelog.d/7886.misc (limited to 'tests') diff --git a/changelog.d/7886.misc b/changelog.d/7886.misc new file mode 100644 index 0000000000..e729ab2451 --- /dev/null +++ b/changelog.d/7886.misc @@ -0,0 +1 @@ +Return an empty body for OPTIONS requests. diff --git a/synapse/http/server.py b/synapse/http/server.py index 8e003689c4..d4f9ad6e67 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -442,21 +442,6 @@ class StaticResource(File): return super().render_GET(request) -def _options_handler(request): - """Request handler for OPTIONS requests - - This is a request handler suitable for return from - _get_handler_for_request. It returns a 200 and an empty body. - - Args: - request (twisted.web.http.Request): - - Returns: - Tuple[int, dict]: http code, response body. 
- """ - return 200, {} - - def _unrecognised_request_handler(request): """Request handler for unrecognised requests @@ -490,11 +475,12 @@ class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" def render_OPTIONS(self, request): - code, response_json_object = _options_handler(request) + request.setResponseCode(204) + request.setHeader(b"Content-Length", b"0") - return respond_with_json( - request, code, response_json_object, send_cors=True, canonical_json=False, - ) + set_cors_headers(request) + + return b"" def getChildWithDefault(self, path, request): if request.method == b"OPTIONS": diff --git a/tests/test_server.py b/tests/test_server.py index 42cada8964..073b2362cc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -193,10 +193,10 @@ class OptionsResourceTests(unittest.TestCase): return channel def test_unknown_options_request(self): - """An OPTIONS requests to an unknown URL still returns 200 OK.""" + """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.result["body"], b"{}") + self.assertEqual(channel.result["code"], b"204") + self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added self.assertTrue( @@ -213,10 +213,10 @@ class OptionsResourceTests(unittest.TestCase): ) def test_known_options_request(self): - """An OPTIONS requests to an known URL still returns 200 OK.""" + """An OPTIONS requests to an known URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/res/") - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.result["body"], b"{}") + self.assertEqual(channel.result["code"], b"204") + self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added self.assertTrue( -- cgit 1.5.1 From 5ea29d7f850b6d2acbbfaf2e81bc5f0625411320 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 09:39:02 -0400 Subject: Convert more of the media code to async/await (#7873) --- changelog.d/7873.misc | 1 + synapse/rest/media/v1/_base.py | 15 ++++---- synapse/rest/media/v1/media_storage.py | 60 +++++++++++++++++-------------- tests/rest/media/v1/test_media_storage.py | 5 ++- 4 files changed, 47 insertions(+), 34 deletions(-) create mode 100644 changelog.d/7873.misc (limited to 'tests') diff --git a/changelog.d/7873.misc b/changelog.d/7873.misc new file mode 100644 index 0000000000..58260764e7 --- /dev/null +++ b/changelog.d/7873.misc @@ -0,0 +1 @@ +Convert more media code to async/await. 
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 595849f9d5..9a847130c0 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -18,7 +18,6 @@ import logging import os import urllib -from twisted.internet import defer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -77,8 +76,9 @@ def respond_404(request): ) -@defer.inlineCallbacks -def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None): +async def respond_with_file( + request, media_type, file_path, file_size=None, upload_name=None +): logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) finish_request(request) else: @@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x): return True -@defer.inlineCallbacks -def respond_with_responder(request, responder, media_type, file_size, upload_name=None): +async def respond_with_responder( + request, responder, media_type, file_size, upload_name=None +): """Responds to the request with given responder. If responder is None then returns 404. @@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam add_file_headers(request, media_type, file_size, upload_name) try: with responder: - yield responder.write_to_consumer(request) + await responder.write_to_consumer(request) except Exception as e: # The majority of the time this will be due to the client having gone # away. Unfortunately, Twisted simply throws a generic exception at us diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 79cb0dddbe..66bc1c3360 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -14,17 +14,18 @@ # limitations under the License. 
import contextlib +import inspect import logging import os import shutil +from typing import Optional -from twisted.internet import defer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from ._base import Responder +from ._base import FileInfo, Responder logger = logging.getLogger(__name__) @@ -46,25 +47,24 @@ class MediaStorage(object): self.filepaths = filepaths self.storage_providers = storage_providers - @defer.inlineCallbacks - def store_file(self, source, file_info): + async def store_file(self, source, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers Args: source: A file like object that should be written - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Returns: - Deferred[str]: the file path written to in the primary media store + the file path written to in the primary media store """ with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) - yield finish_cb() + await finish_cb() return fname @@ -75,7 +75,7 @@ class MediaStorage(object): Actually yields a 3-tuple (file, fname, finish_cb), where file is a file like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns a Deferred. + on disk, and finish_cb is a function that returns an awaitable. fname can be used to read the contents from after upload, e.g. to generate thumbnails. @@ -91,7 +91,7 @@ class MediaStorage(object): with media_storage.store_into_file(info) as (f, fname, finish_cb): # .. write into f ... - yield finish_cb() + await finish_cb() """ path = self._file_info_to_path(file_info) @@ -103,10 +103,13 @@ class MediaStorage(object): finished_called = [False] - @defer.inlineCallbacks - def finish(): + async def finish(): for provider in self.storage_providers: - yield provider.store_file(path, file_info) + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = provider.store_file(path, file_info) + if inspect.isawaitable(result): + await result finished_called[0] = True @@ -123,17 +126,15 @@ class MediaStorage(object): if not finished_called: raise Exception("Finished callback not called") - @defer.inlineCallbacks - def fetch_media(self, file_info): + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info (FileInfo) + file_info Returns: - Deferred[Responder|None]: Returns a Responder if the file was found, - otherwise None. + Returns a Responder if the file was found, otherwise None. """ path = self._file_info_to_path(file_info) @@ -142,23 +143,26 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = provider.fetch(path, file_info) + # Fetch is supposed to return an Awaitable, but guard against + # improper implementations. 
+ if inspect.isawaitable(res): + res = await res if res: logger.debug("Streaming %s from %s", path, provider) return res return None - @defer.inlineCallbacks - def ensure_media_is_in_local_cache(self, file_info): + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: """Ensures that the given file is in the local cache. Attempts to download it from storage providers if it isn't. Args: - file_info (FileInfo) + file_info Returns: - Deferred[str]: Full path to local file + Full path to local file """ path = self._file_info_to_path(file_info) local_path = os.path.join(self.local_media_directory, path) @@ -170,14 +174,18 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = provider.fetch(path, file_info) + # Fetch is supposed to return an Awaitable, but guard against + # improper implementations. + if inspect.isawaitable(res): + res = await res if res: with res: consumer = BackgroundFileConsumer( open(local_path, "wb"), self.hs.get_reactor() ) - yield res.write_to_consumer(consumer) - yield consumer.wait() + await res.write_to_consumer(consumer) + await consumer.wait() return local_path raise Exception("file could not be found") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 66fa5978b2..f4f3e56777 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -26,6 +26,7 @@ import attr from parameterized import parameterized_class from PIL import Image as Image +from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable @@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase): # This uses a real blocking threadpool so we have to wait for it to be # actually done :/ - x = self.media_storage.ensure_media_is_in_local_cache(file_info) + x = defer.ensureDeferred( + self.media_storage.ensure_media_is_in_local_cache(file_info) + ) # Hotloop until the threadpool does its job... self.wait_on_thread(x) -- cgit 1.5.1 From b975fa2e9952f1f8ac2cddb15c287768bf9b0b4e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 10:59:51 -0400 Subject: Convert state resolution to async/await (#7942) --- changelog.d/7942.misc | 1 + synapse/api/auth.py | 12 ++- synapse/events/builder.py | 4 +- synapse/federation/sender/__init__.py | 4 +- synapse/handlers/presence.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/state/__init__.py | 95 ++++++++---------- synapse/state/v1.py | 15 ++- synapse/state/v2.py | 107 ++++++++++----------- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/storage/data_stores/main/user_directory.py | 4 +- synapse/storage/persist_events.py | 5 +- tests/federation/test_federation_sender.py | 19 ++-- tests/state/test_v2.py | 17 ++-- tests/storage/test_room.py | 8 +- tests/test_state.py | 72 ++++++++------ tests/test_utils/__init__.py | 7 +- 18 files changed, 198 insertions(+), 184 deletions(-) create mode 100644 changelog.d/7942.misc (limited to 'tests') diff --git a/changelog.d/7942.misc b/changelog.d/7942.misc new file mode 100644 index 0000000000..b504cf4e6f --- /dev/null +++ b/changelog.d/7942.misc @@ -0,0 +1 @@ +Convert state resolution to async/await. 
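Not every caller of the state handler is converted in this commit: code still written with @defer.inlineCallbacks (api/auth.py, the bulk push rule evaluator and several stores below) keeps working by wrapping the now-async calls in defer.ensureDeferred, which turns the coroutine into a Deferred that inlineCallbacks can drive. A minimal sketch of that bridge, with a hypothetical stand-in for the real handler method:

from twisted.internet import defer, task


async def get_state_entry_count(room_id):
    # Stand-in for a handler method that has been converted to async/await.
    state_ids = {("m.room.create", ""): "$create:" + room_id}
    return len(state_ids)


@defer.inlineCallbacks
def legacy_caller(room_id):
    # inlineCallbacks code works in terms of Deferreds, so wrap the coroutine.
    count = yield defer.ensureDeferred(get_state_entry_count(room_id))
    return count


def main(reactor):
    d = legacy_caller("!room:example.com")
    d.addCallback(lambda count: print("state entries:", count))
    return d


if __name__ == "__main__":
    task.react(main)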
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 40dc62ef6c..b53e8451e5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -127,8 +127,10 @@ class Auth(object): if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id + member = yield defer.ensureDeferred( + self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id + ) ) membership = member.membership if member else None @@ -665,8 +667,10 @@ class Auth(object): ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" + visibility = yield defer.ensureDeferred( + self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) ) if ( visibility diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 92aadfe7ef..0bb216419a 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,8 +106,8 @@ class EventBuilder(object): Deferred[FrozenEvent] """ - state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids + state_ids = yield defer.ensureDeferred( + self._state.get_current_state_ids(self.room_id, prev_event_ids) ) auth_ids = yield self._auth.compute_auth_events(self, state_ids) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 99ce73e081..ba4ddd2370 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -330,7 +330,9 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield self.state.get_current_hosts_in_room(room_id) + domains = yield defer.ensureDeferred( + self.state.get_current_hosts_in_room(room_id) + ) domains = [ d for d in domains diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8e99c83d9d..b3a3bb8c3f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. 
- user_ids = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + users = await self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, users)) states_d = await self.current_state_for_users(user_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 43ffe6faf0..472ddf9f7d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -304,7 +304,9 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred( + context.get_current_state_ids() + ) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 495d9f04c8..25ccef5aa5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,14 +16,12 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Set +from typing import Awaitable, Dict, Iterable, List, Optional, Set import attr from frozendict import frozendict from prometheus_client import Histogram -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase @@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.roommember import ProfileInfo from synapse.types import StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -108,8 +107,7 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def get_current_state( + async def get_current_state( self, room_id, event_type=None, state_key="", latest_event_ids=None ): """ Retrieves the current state for the room. 
This is done by @@ -126,20 +124,20 @@ class StateHandler(object): map from (type, state_key) to event """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: event_id = state.get((event_type, state_key)) event = None if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) return event - state_map = yield self.store.get_events( + state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) state = { @@ -148,8 +146,7 @@ class StateHandler(object): return state - @defer.inlineCallbacks - def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids(self, room_id, latest_event_ids=None): """Get the current state, or the state at a set of events, for a room Args: @@ -164,41 +161,38 @@ class StateHandler(object): (event_type, state_key) -> event_id """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state return state - @defer.inlineCallbacks - def get_current_users_in_room(self, room_id, latest_event_ids=None): + async def get_current_users_in_room( + self, room_id: str, latest_event_ids: Optional[List[str]] = None + ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. Args: - room_id (str): The ID of the room. - latest_event_ids (List[str]|None): Precomputed list of latest - event IDs. Will be computed if None. + room_id: The ID of the room. + latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their - profileinfo. + Dictionary of user IDs to their profileinfo. 
""" if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_users_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + joined_users = await self.store.get_joined_users_from_state(room_id, entry) return joined_users - @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id): - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + async def get_current_hosts_in_room(self, room_id): + event_ids = await self.store.get_latest_event_ids_in_room(room_id) + return await self.get_hosts_in_room_at_events(room_id, event_ids) - @defer.inlineCallbacks - def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events(self, room_id, event_ids): """Get the hosts that were in a room at the given event ids Args: @@ -208,12 +202,11 @@ class StateHandler(object): Returns: Deferred[list[str]]: the hosts in the room at the given events """ - entry = yield self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = yield self.store.get_joined_hosts(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, event_ids) + joined_hosts = await self.store.get_joined_hosts(room_id, entry) return joined_hosts - @defer.inlineCallbacks - def compute_event_context( + async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None ): """Build an EventContext structure for the event. @@ -278,7 +271,7 @@ class StateHandler(object): # otherwise, we'll need to resolve the state across the prev_events. logger.debug("calling resolve_state_groups from compute_event_context") - entry = yield self.resolve_state_groups_for_events( + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -295,7 +288,7 @@ class StateHandler(object): # if not state_group_before_event: - state_group_before_event = yield self.state_store.store_state_group( + state_group_before_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -335,7 +328,7 @@ class StateHandler(object): state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = yield self.state_store.store_state_group( + state_group_after_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event, @@ -353,8 +346,7 @@ class StateHandler(object): ) @measure_func() - @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. 
@@ -373,7 +365,7 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.state_store.get_state_groups_ids( + state_groups_ids = await self.state_store.get_state_groups_ids( room_id, event_ids ) @@ -382,7 +374,7 @@ class StateHandler(object): elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) + prev_group, delta_ids = await self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -391,9 +383,9 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) - result = yield self._state_resolution_handler.resolve_state_groups( + result = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, @@ -402,8 +394,7 @@ class StateHandler(object): ) return result - @defer.inlineCallbacks - def resolve_events(self, room_version, state_sets, event): + async def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,7 +405,7 @@ class StateHandler(object): state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, event.room_id, room_version, @@ -451,9 +442,8 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - @defer.inlineCallbacks @log_function - def resolve_state_groups( + async def resolve_state_groups( self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -479,13 +469,13 @@ class StateResolutionHandler(object): state_res_store (StateResolutionStore) Returns: - Deferred[_StateCacheEntry]: resolved state + _StateCacheEntry: resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) - with (yield self.resolve_linearizer.queue(group_names)): + with (await self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: @@ -517,7 +507,7 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, room_id, room_version, @@ -598,7 +588,7 @@ def resolve_events_with_store( state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", -): +) -> Awaitable[StateMap[str]]: """ Args: room_id: the room we are working in @@ -619,8 +609,7 @@ def resolve_events_with_store( state_res_store: a place to fetch events from Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + a map from (type, state_key) to event_id. 
""" v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 7b531a8337..ab5e24841d 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,9 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional - -from twisted.internet import defer +from typing import Awaitable, Callable, Dict, List, Optional from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,12 +30,11 @@ logger = logging.getLogger(__name__) POWER_KEY = (EventTypes.PowerLevels, "") -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( room_id: str, state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable, + state_map_factory: Callable[[List[str]], Awaitable], ): """ Args: @@ -56,7 +53,7 @@ def resolve_events_with_store( state_map_factory: will be called with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + an Awaitable that resolves to a dict of event_id to event. Returns: Deferred[dict[(str, str), str]]: @@ -80,7 +77,7 @@ def resolve_events_with_store( # dict[str, FrozenEvent]: a map from state event id to event. Only includes # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) + state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -110,7 +107,7 @@ def resolve_events_with_store( "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) - state_map_new = yield state_map_factory(new_needed_events) + state_map_new = await state_map_factory(new_needed_events) for event in state_map_new.values(): if event.room_id != room_id: raise Exception( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index bf6caa0946..6634955cdc 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Dict, List, Optional -from twisted.internet import defer - import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,14 +30,13 @@ from synapse.util import Clock logger = logging.getLogger(__name__) -# We want to yield to the reactor occasionally during state res when dealing +# We want to await to the reactor occasionally during state res when dealing # with large data sets, so that we don't exhaust the reactor. This is done by -# yielding to reactor during loops every N iterations. -_YIELD_AFTER_ITERATIONS = 100 +# awaiting to reactor during loops every N iterations. +_AWAIT_AFTER_ITERATIONS = 100 -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, @@ -87,7 +84,7 @@ def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. 
- auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) full_conflicted_set = set( itertools.chain( @@ -95,7 +92,7 @@ def resolve_events_with_store( ) ) - events = yield state_res_store.get_events( + events = await state_res_store.get_events( [eid for eid in full_conflicted_set if eid not in event_map], allow_rejected=True, ) @@ -118,14 +115,14 @@ def resolve_events_with_store( eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) - sorted_power_events = yield _reverse_topological_power_sort( + sorted_power_events = await _reverse_topological_power_sort( clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -148,13 +145,13 @@ def resolve_events_with_store( logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) - leftover_events = yield _mainline_sort( + leftover_events = await _mainline_sort( clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -174,8 +171,7 @@ def resolve_events_with_store( return resolved_state -@defer.inlineCallbacks -def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. @@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(room_id, event_id, event_map, state_res_store) + event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): @@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): return int(level) -@defer.inlineCallbacks -def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference(state_sets, event_map, state_res_store): """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. 
@@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Deferred[set[str]]: Set of event IDs """ - difference = yield state_res_store.get_auth_chain_difference( + difference = await state_res_store.get_auth_chain_difference( [set(state_set.values()) for state_set in state_sets] ) @@ -292,8 +287,7 @@ def _is_power_event(event): return False -@defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph( +async def _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event @@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(room_id, eid, event_map, state_res_store) + event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph( graph.setdefault(eid, set()).add(aid) -@defer.inlineCallbacks -def _reverse_topological_power_sort( +async def _reverse_topological_power_sort( clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, @@ -344,26 +337,26 @@ def _reverse_topological_power_sort( graph = {} for idx, event_id in enumerate(event_ids, start=1): - yield _add_event_and_auth_chain_to_graph( + await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_to_pl = {} for idx, event_id in enumerate(graph, start=1): - pl = yield _get_power_level_for_sender( + pl = await _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. 
- if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) def _get_power_order(event_id): ev = event_map[event_id] @@ -378,8 +371,7 @@ def _reverse_topological_power_sort( return sorted_events -@defer.inlineCallbacks -def _iterative_auth_checks( +async def _iterative_auth_checks( clock, room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the @@ -405,7 +397,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) @@ -420,7 +412,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(room_id, ev_id, event_map, state_res_store) + ev = await _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -438,16 +430,15 @@ def _iterative_auth_checks( except AuthError: pass - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) return resolved_state -@defer.inlineCallbacks -def _mainline_sort( +async def _mainline_sort( clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on @@ -474,21 +465,21 @@ def _mainline_sort( idx = 0 while pl: mainline.append(pl) - pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) + pl_ev = await _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) idx += 1 @@ -498,23 +489,24 @@ def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): - depth = yield _get_mainline_depth_for_event( + depth = await _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. 
- if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids -@defer.inlineCallbacks -def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): +async def _get_mainline_depth_for_event( + event, mainline_map, event_map, state_res_store +): """Get the mainline depths for the given event based on the mainline map Args: @@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor return 0 -@defer.inlineCallbacks -def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store @@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: - events = yield state_res_store.get_events([event_id], allow_rejected=True) + events = await state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) event = event_map.get(event_id) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index d181488db7..c229248101 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -259,7 +259,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 29765890ee..a92e401e88 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 6b8130bf0f..942e51fd3a 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = yield state.get_current_users_in_room(room_id) + users_with_profile = yield defer.ensureDeferred( + state.get_current_users_in_room(room_id) + ) user_ids = set(users_with_profile) # Update each user in the user directory. 
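The three storage hunks above all rely on the same bridge: code still written with `@defer.inlineCallbacks` cannot `yield` a bare coroutine, so the now-async methods are wrapped in `defer.ensureDeferred`, which converts the coroutine into a Deferred. A minimal, self-contained sketch of that bridge (the coroutine body is a stand-in, not Synapse code):

```python
from twisted.internet import defer


async def get_state_ids():
    # Stand-in for a context method that has been converted to a coroutine.
    return {("m.room.member", "@user:test"): "$event:test"}


@defer.inlineCallbacks
def old_style_caller():
    # inlineCallbacks generators can only yield Deferreds, so the coroutine
    # is wrapped with ensureDeferred before being yielded.
    state_ids = yield defer.ensureDeferred(get_state_ids())
    return len(state_ids)


# Nothing in the coroutine actually suspends, so the Deferred fires
# synchronously and the callback runs immediately.
old_style_caller().addCallback(print)  # prints 1
```

The same consideration drives the mock changes later in this series: a coroutine can only be awaited once, so mocked async methods use a `side_effect` that builds a fresh awaitable per call (the `make_awaitable` helper) rather than a single shared `return_value`.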
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa46041676..78fbdcdee8 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,7 +29,6 @@ from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap @@ -648,6 +647,10 @@ class EventsPersistenceStorage(object): room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 1a9bd5f37d..d1bd18da39 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -26,21 +26,24 @@ from synapse.rest import admin from synapse.rest.client.v1 import login from synapse.types import JsonDict, ReadReceipt +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): + mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) + # Ensure a new Awaitable is created for each call. + mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + ["test", "host2"] + ) return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), + state_handler=mock_state_handler, federation_transport_client=Mock(spec=["send_transaction"]), ) @override_config({"send_federation": True}) def test_send_receipts(self): - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out get_current_hosts_in_room - mock_state_handler = hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - # stub out get_users_who_share_room_with_user so that it claims that # `@user2:host2` is in the room def get_users_who_share_room_with_user(user_id): diff --git a/tests/state/test_v2.py 
b/tests/state/test_v2.py index 38f9b423ef..f2955a9c69 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +from typing import List import attr @@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - state_before = self.successResultOf(state_d) + state_before = self.successResultOf(defer.ensureDeferred(state_d)) state_after = dict(state_before) if fake_event.state_key is not None: @@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(self.event_map), ) - state = self.successResultOf(state_d) + state = self.successResultOf(defer.ensureDeferred(state_d)) self.assert_dict(self.expected_combined_state, state) @@ -608,9 +609,11 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + return defer.succeed( + {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + ) - def _get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids: List[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -622,10 +625,10 @@ class TestStateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth + event_ids: The event IDs of the events to fetch the auth chain for. Must be state events. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + List of event IDs of the auth chain. """ # Simple DFS for auth chain @@ -648,4 +651,4 @@ class TestStateResolutionStore(object): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return set(chains[0]).union(*chains[1:]) - common + return defer.succeed(set(chains[0]).union(*chains[1:]) - common) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index b1dceb2918..1d77b4a2d6 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( @@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( diff --git a/tests/test_state.py b/tests/test_state.py index 66f22f6813..4858e8fc59 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -97,17 +97,19 @@ class StateGroupStore(object): self._group_to_state[state_group] = dict(current_state_ids) - return state_group + return defer.succeed(state_group) def get_events(self, event_ids, **kwargs): - return { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } + return defer.succeed( + { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } + ) def 
get_state_group_delta(self, name): - return None, None + return defer.succeed((None, None)) def register_events(self, events): for e in events: @@ -120,7 +122,7 @@ class StateGroupStore(object): self._event_to_state_group[event_id] = state_group def get_room_version_id(self, room_id): - return RoomVersions.V1.identifier + return defer.succeed(RoomVersions.V1.identifier) class DictObj(dict): @@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase): context_store = {} # type: dict[str, EventContext] for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield 
defer.ensureDeferred(self.state.compute_event_context(event)) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( {e.event_id for e in old_state}, set(current_state_ids.values()) @@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) prev_state_ids = yield context.get_prev_state_ids() @@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) + @defer.inlineCallbacks def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = self.store.store_state_group( + sg1 = yield self.store.store_state_group( prev_event_id_1, event.room_id, None, @@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = self.store.store_state_group( + sg2 = yield self.store.store_state_group( prev_event_id_2, event.room_id, None, @@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_2, sg2) - return self.state.compute_event_context(event) + result = yield defer.ensureDeferred(self.state.compute_event_context(event)) + return result diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 7b345b03bb..508aeba078 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,7 +17,7 @@ """ Utilities for running the unit tests """ -from typing import Awaitable, TypeVar +from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: # if next didn't raise, the awaitable hasn't completed. 
raise Exception("awaitable has not yet completed") + + +async def make_awaitable(result: Any): + """Create an awaitable that just returns a result.""" + return result -- cgit 1.5.1 From 3fc8fdd150e2471d6e96b842e364d9421066f4ba Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 07:50:44 -0400 Subject: Support oEmbed for media previews. (#7920) Fixes previews of Twitter URLs by using their oEmbed endpoint to grab content. --- changelog.d/7920.feature | 1 + synapse/rest/media/v1/preview_url_resource.py | 265 +++++++++++++++++++++----- tests/rest/media/v1/test_url_preview.py | 142 +++++++++++++- 3 files changed, 355 insertions(+), 53 deletions(-) create mode 100644 changelog.d/7920.feature (limited to 'tests') diff --git a/changelog.d/7920.feature b/changelog.d/7920.feature new file mode 100644 index 0000000000..4093f5d329 --- /dev/null +++ b/changelog.d/7920.feature @@ -0,0 +1 @@ +Support oEmbed for media previews. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e52c86c798..13d1a6d2ed 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -26,6 +26,7 @@ import traceback from typing import Dict, Optional from urllib import parse as urlparse +import attr from canonicaljson import json from twisted.internet import defer @@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) OG_TAG_NAME_MAXLEN = 50 OG_TAG_VALUE_MAXLEN = 1000 +ONE_HOUR = 60 * 60 * 1000 + +# A map of globs to API endpoints. +_oembed_globs = { + # Twitter. + "https://publish.twitter.com/oembed": [ + "https://twitter.com/*/status/*", + "https://*.twitter.com/*/status/*", + "https://twitter.com/*/moments/*", + "https://*.twitter.com/*/moments/*", + # Include the HTTP versions too. + "http://twitter.com/*/status/*", + "http://*.twitter.com/*/status/*", + "http://twitter.com/*/moments/*", + "http://*.twitter.com/*/moments/*", + ], +} +# Convert the globs to regular expressions. +_oembed_patterns = {} +for endpoint, globs in _oembed_globs.items(): + for glob in globs: + # Convert the glob into a sane regular expression to match against. The + # rules followed will be slightly different for the domain portion vs. + # the rest. + # + # 1. The scheme must be one of HTTP / HTTPS (and have no globs). + # 2. The domain can have globs, but we limit it to characters that can + # reasonably be a domain part. + # TODO: This does not attempt to handle Unicode domain names. + # 3. Other parts allow a glob to be any one, or more, characters. + results = urlparse.urlparse(glob) + + # Ensure the scheme does not have wildcards (and is a sane scheme). + if results.scheme not in {"http", "https"}: + raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,)) + + pattern = urlparse.urlunparse( + [ + results.scheme, + re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), + ] + + [re.escape(part).replace("\\*", ".+") for part in results[2:]] + ) + _oembed_patterns[re.compile(pattern)] = endpoint + + +@attr.s +class OEmbedResult: + # Either HTML content or URL must be provided. + html = attr.ib(type=Optional[str]) + url = attr.ib(type=Optional[str]) + title = attr.ib(type=Optional[str]) + # Number of seconds to cache the content. 
+ cache_age = attr.ib(type=int) + + +class OEmbedError(Exception): + """An error occurred processing the oEmbed object.""" + class PreviewUrlResource(DirectServeJsonResource): isLeaf = True @@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource): cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour - expiry_ms=60 * 60 * 1000, + expiry_ms=ONE_HOUR, ) if self._worker_run_media_background_jobs: @@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") + def _get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Params: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in _oembed_patterns.items(): + if url_pattern.fullmatch(url): + return endpoint + + # No match. + return None + + async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: + """ + Request content from an oEmbed endpoint. + + Params: + endpoint: The oEmbed API endpoint. + url: The URL to pass to the API. + + Returns: + An object representing the metadata returned. + + Raises: + OEmbedError if fetching or parsing of the oEmbed information fails. + """ + try: + logger.debug("Trying to get oEmbed content for url '%s'", url) + result = await self.client.get_json( + endpoint, + # TODO Specify max height / width. + # Note that only the JSON format is supported. + args={"url": url}, + ) + + # Ensure there's a version of 1.0. + if result.get("version") != "1.0": + raise OEmbedError("Invalid version: %s" % (result.get("version"),)) + + oembed_type = result.get("type") + + # Ensure the cache age is None or an int. + cache_age = result.get("cache_age") + if cache_age: + cache_age = int(cache_age) + + oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) + + # HTML content. + if oembed_type == "rich": + oembed_result.html = result.get("html") + return oembed_result + + if oembed_type == "photo": + oembed_result.url = result.get("url") + return oembed_result + + # TODO Handle link and video types. + + if "thumbnail_url" in result: + oembed_result.url = result.get("thumbnail_url") + return oembed_result + + raise OEmbedError("Incompatible oEmbed information.") + + except OEmbedError as e: + # Trap OEmbedErrors first so we can directly re-raise them. + logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) + raise + + except Exception as e: + # Trap any exception and let the code follow as usual. + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) + raise OEmbedError() from e + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a @@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + # If this URL can be accessed via oEmbed, use that instead. + url_to_download = url + oembed_url = self._get_oembed_url(url) + if oembed_url: + # The result might be a new URL to download, or it might be HTML content. 
try: - logger.debug("Trying to get preview for url '%s'", url) - length, headers, uri, code = await self.client.get_file( - url, - output_stream=f, - max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url, e) + oembed_result = await self._get_oembed_content(oembed_url, url) + if oembed_result.url: + url_to_download = oembed_result.url + elif oembed_result.html: + url_to_download = None + except OEmbedError: + # If an error occurs, try doing a normal preview. + pass - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - await finish() + if url_to_download: + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + logger.debug("Trying to get preview for url '%s'", url_to_download) + length, headers, uri, code = await self.client.get_file( + url_to_download, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url_to_download, e) + + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + await finish() + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers["ETag"][0] if "ETag" in headers else None + else: + html_bytes = oembed_result.html.encode("utf-8") # type: ignore + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + f.write(html_bytes) + await finish() + + media_type = "text/html" + download_name = oembed_result.title + length = len(html_bytes) + # If a specific cache age was not given, assume 1 hour. 
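Stepping back from the hunk for a moment: the two oEmbed response shapes this branch distinguishes are easiest to see side by side. A sketch with made-up values (only the field names `version`, `type`, `url`, `html`, and `cache_age` come from the code above; the URLs and text are illustrative):

```python
# "photo": the endpoint hands back an image URL; the preview code swaps the
# original URL for this one and downloads it as a normal media preview.
photo_response = {
    "version": "1.0",
    "type": "photo",
    "url": "http://cdn.example.com/some-image.jpg",  # illustrative URL
    "cache_age": 3600,
}

# "rich": the endpoint hands back HTML directly; it is stored as text/html,
# parsed for Open Graph data, and the original URL is never fetched.
rich_response = {
    "version": "1.0",
    "type": "rich",
    "html": "<div>Content Preview</div>",  # illustrative snippet
}


def chosen_action(oembed: dict) -> str:
    """Loosely mirrors the branching above, for illustration only."""
    if oembed.get("version") != "1.0":
        return "error: unsupported oEmbed version"
    if oembed["type"] == "photo":
        return "download " + oembed["url"]
    if oembed["type"] == "rich":
        return "store inline HTML (%d bytes)" % len(oembed["html"])
    return "fall back to thumbnail_url if present"


assert chosen_action(photo_response).startswith("download ")
assert chosen_action(rich_response).startswith("store inline")
```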
+ expires = oembed_result.cache_age or ONE_HOUR + uri = oembed_url + code = 200 + etag = None try: - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() - download_name = get_filename_from_headers(headers) - await self.store.store_local_media( media_id=file_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=download_name, media_length=length, user_id=user, @@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource): "filename": fname, "uri": uri, "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, + "expires": expires, + "etag": etag, } def _start_expire_url_cache_data(self): @@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource): # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. - expire_before = now - 2 * 24 * 60 * 60 * 1000 + expire_before = now - 2 * 24 * ONE_HOUR media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2826211f32..74765a582b 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -12,8 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os +import re + +from mock import patch import attr @@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://matrix.org", shorthand=False @@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def test_non_ascii_preview_httpequiv(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"" @@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"" @@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_overlong_title(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"" @@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ IP addresses can be previewed directly. 
""" - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False @@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Hardcode the URL resolving to the IP we want. self.lookups["example.com"] = [ (IPv4Address, "1.1.1.2"), - (IPv4Address, "8.8.8.8"), + (IPv4Address, "10.1.2.3"), ] request, channel = self.make_request( @@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Accept-Language header is sent to the remote server """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server request, channel = self.make_request( @@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase): ), server.data, ) + + def test_oembed_photo(self): + """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result).encode("utf-8") + + end_content = ( + b"" + b"Some Title" + b'' + b"" + ) + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(oembed_content),) + + oembed_content + ) + + self.pump() + + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + ) + + def test_oembed_rich(self): + """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "rich", + "html": "
<div>Content Preview</div>
", + } + end_content = json.dumps(result).encode("utf-8") + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + {"og:title": None, "og:description": "Content Preview"}, + ) -- cgit 1.5.1 From c4268e3da64f1abb5b31deaeb5769adb6510c0a7 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 27 Jul 2020 05:22:52 -0700 Subject: Convert tests/rest/admin/test_room.py to unix file endings (#7953) Converts tests/rest/admin/test_room.py to have unix file endings after they were accidentally changed in #7613. Keeping the same changelog as #7613 as it hasn't gone out in a release yet. --- changelog.d/7953.feature | 1 + tests/rest/admin/test_room.py | 2894 ++++++++++++++++++++--------------------- 2 files changed, 1448 insertions(+), 1447 deletions(-) create mode 100644 changelog.d/7953.feature (limited to 'tests') diff --git a/changelog.d/7953.feature b/changelog.d/7953.feature new file mode 100644 index 0000000000..945b5c743c --- /dev/null +++ b/changelog.d/7953.feature @@ -0,0 +1 @@ +Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms//delete`). Contributed by @dklimpel. \ No newline at end of file diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 946f06d151..ba8552c29f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1,1447 +1,1447 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Dirk Klimpel -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import urllib.parse -from typing import List, Optional - -from mock import Mock - -import synapse.rest.admin -from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room - -from tests import unittest - -"""Tests admin REST events for /rooms paths.""" - - -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. 
- """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - -class DeleteRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_tok = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - self.room_id = self.helper.create_room_as( - self.other_user, tok=self.other_user_tok - ) - self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error 403 is returned. - """ - - request, channel = self.make_request( - "POST", self.url, json.dumps({}), access_token=self.other_user_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_room_does_not_exist(self): - """ - Check that unknown rooms/server return error 404. 
- """ - url = "/_synapse/admin/v1/rooms/!unknown:test/delete" - - request, channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_room_is_not_valid(self): - """ - Check that invalid room names, return an error 400. - """ - url = "/_synapse/admin/v1/rooms/invalidroom/delete" - - request, channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "invalidroom is not a legal room ID", channel.json_body["error"], - ) - - def test_new_room_user_does_not_exist(self): - """ - Tests that the user ID must be from local server but it does not have to exist. - """ - body = json.dumps({"new_room_user_id": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertIn("new_room_id", channel.json_body) - self.assertIn("kicked_users", channel.json_body) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - def test_new_room_user_is_not_local(self): - """ - Check that only local users can create new room to move members. - """ - body = json.dumps({"new_room_user_id": "@not:exist.bla"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "User must be our own: @not:exist.bla", channel.json_body["error"], - ) - - def test_block_is_not_bool(self): - """ - If parameter `block` is not boolean, return an error - """ - body = json.dumps({"block": "NotBool"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - - def test_purge_room_and_block(self): - """Test to purge a room and block it. - Members will not be moved to a new room and will not receive a message. 
- """ - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Test that room is not blocked - self._is_blocked(self.room_id, expect=False) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - body = json.dumps({"block": True}) - - request, channel = self.make_request( - "POST", - self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(None, channel.json_body["new_room_id"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - self._is_purged(self.room_id) - self._is_blocked(self.room_id, expect=True) - self._has_no_members(self.room_id) - - def test_purge_room_and_not_block(self): - """Test to purge a room and do not block it. - Members will not be moved to a new room and will not receive a message. - """ - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Test that room is not blocked - self._is_blocked(self.room_id, expect=False) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - body = json.dumps({"block": False}) - - request, channel = self.make_request( - "POST", - self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(None, channel.json_body["new_room_id"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - self._is_purged(self.room_id) - self._is_blocked(self.room_id, expect=False) - self._has_no_members(self.room_id) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - Members will be moved to a new room and will receive a message. 
- """ - self.event_creation_handler._block_events_without_consent_error = None - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 - ) - - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("new_room_id", channel.json_body) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - # Test that member has moved to new room - self._is_member( - room_id=channel.json_body["new_room_id"], user_id=self.other_user - ) - - self._is_purged(self.room_id) - self._has_no_members(self.room_id) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - Members will be moved to a new room and will receive a message. - """ - self.event_creation_handler._block_events_without_consent_error = None - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("new_room_id", channel.json_body) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - # Test that member has moved to new room - self._is_member( - room_id=channel.json_body["new_room_id"], user_id=self.other_user - ) - - self._is_purged(self.room_id) - self._has_no_members(self.room_id) - - # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=403) - - def _is_blocked(self, room_id, expect=True): - """Assert that the room is blocked or not - """ - d = self.store.is_room_blocked(room_id) - if expect: - self.assertTrue(self.get_success(d)) - else: - 
self.assertIsNone(self.get_success(d)) - - def _has_no_members(self, room_id): - """Assert there is now no longer anyone in the room - """ - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def _is_member(self, room_id, user_id): - """Test that user is member of the room - """ - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertIn(user_id, users_in_room) - - def _is_purged(self, room_id): - """Test that the following tables have been purged of all rows related to the room. - """ - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. 
- for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - -class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - directory.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - # Create user - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_list_rooms(self): - """Test that we can list rooms""" - # Create 3 test rooms - total_rooms = 3 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - - # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) - - # Check that response json body contains a "rooms" key - self.assertTrue( - "rooms" in channel.json_body, - msg="Response body does not " "contain a 'rooms' key", - ) - - # Check that 3 rooms were returned - self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) - - # Check their room_ids match - returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] - self.assertEqual(room_ids, returned_room_ids) - - # Check that all fields are available - for r in channel.json_body["rooms"]: - self.assertIn("name", r) - self.assertIn("canonical_alias", r) - self.assertIn("joined_members", r) - self.assertIn("joined_local_members", r) - self.assertIn("version", r) - self.assertIn("creator", r) - self.assertIn("encryption", r) - self.assertIn("federatable", r) - self.assertIn("public", r) - self.assertIn("join_rules", r) - self.assertIn("guest_access", r) - self.assertIn("history_visibility", r) - self.assertIn("state_events", r) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # Should be 0 as we aren't paginating - self.assertEqual(channel.json_body["offset"], 0) - - # Check that the prev_batch parameter is not present - self.assertNotIn("prev_batch", channel.json_body) - - # We shouldn't receive a next token here as there's no further rooms to show - self.assertNotIn("next_batch", channel.json_body) - - def 
test_list_rooms_pagination(self): - """Test that we can get a full list of rooms through pagination""" - # Create 5 test rooms - total_rooms = 5 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Set the name of the rooms so we get a consistent returned ordering - for idx, room_id in enumerate(room_ids): - self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - returned_room_ids = [] - start = 0 - limit = 2 - - run_count = 0 - should_repeat = True - while should_repeat: - run_count += 1 - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( - start, - limit, - "name", - ) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - self.assertTrue("rooms" in channel.json_body) - for r in channel.json_body["rooms"]: - returned_room_ids.append(r["room_id"]) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # We're only getting 2 rooms each page, so should be 2 * last run_count - self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) - - if run_count > 1: - # Check the value of prev_batch is correct - self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) - - if "next_batch" not in channel.json_body: - # We have reached the end of the list - should_repeat = False - else: - # Make another query with an updated start value - start = channel.json_body["next_batch"] - - # We should've queried the endpoint 3 times - self.assertEqual( - run_count, - 3, - msg="Should've queried 3 times for 5 rooms with limit 2 per query", - ) - - # Check that we received all of the room ids - self.assertEqual(room_ids, returned_room_ids) - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - def test_correct_room_attributes(self): - """Test the correct attributes for a room are returned""" - # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - test_alias = "#test:test" - test_room_name = "something" - - # Have another user join the room - user_2 = self.register_user("user4", "pass") - user_tok_2 = self.login("user4", "pass") - self.helper.join(room_id, user_2, tok=user_tok_2) - - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=self.admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=self.admin_user_tok, - ) - - # Set a name for the room - self.helper.send_state( - room_id, "m.room.name", 
{"name": test_room_name}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that only one room was returned - self.assertEqual(len(rooms), 1) - - # And that the value of the total_rooms key was correct - self.assertEqual(channel.json_body["total_rooms"], 1) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that all provided attributes are set - r = rooms[0] - self.assertEqual(room_id, r["room_id"]) - self.assertEqual(test_room_name, r["name"]) - self.assertEqual(test_alias, r["canonical_alias"]) - - def test_room_list_sort_order(self): - """Test room list sort ordering. alphabetical name versus number of members, - reversing the order, etc. - """ - - def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % ( - urllib.parse.quote(test_alias), - ) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=admin_user_tok, - ) - - def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, - ): - """Request the list of rooms in a certain order. 
Assert that order is what - we expect - - Args: - order_type: The type of ordering to give the server - expected_room_list: The list of room_ids in the order we expect to get - back from the server - """ - # Request the list of rooms in the given order - url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) - if reverse: - url += "&dir=b" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check for the correct total_rooms value - self.assertEqual(channel.json_body["total_rooms"], 3) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that rooms were returned in alphabetical order - returned_order = [r["room_id"] for r in rooms] - self.assertListEqual(expected_room_list, returned_order) # order is checked - - # Create 3 test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C - self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, - ) - - # Set room canonical room aliases - _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) - _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) - _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) - - # Set room member size in the reverse order. 
room 1 -> 1 member, 2 -> 2, 3 -> 3 - user_1 = self.register_user("bob1", "pass") - user_1_tok = self.login("bob1", "pass") - self.helper.join(room_id_2, user_1, tok=user_1_tok) - - user_2 = self.register_user("bob2", "pass") - user_2_tok = self.login("bob2", "pass") - self.helper.join(room_id_3, user_2, tok=user_2_tok) - - user_3 = self.register_user("bob3", "pass") - user_3_tok = self.login("bob3", "pass") - self.helper.join(room_id_3, user_3, tok=user_3_tok) - - # Test different sort orders, with forward and reverse directions - _order_test("name", [room_id_1, room_id_2, room_id_3]) - _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) - _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) - _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) - _order_test( - "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True - ) - - _order_test("version", [room_id_1, room_id_2, room_id_3]) - _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("creator", [room_id_1, room_id_2, room_id_3]) - _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("encryption", [room_id_1, room_id_2, room_id_3]) - _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("federatable", [room_id_1, room_id_2, room_id_3]) - _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("public", [room_id_1, room_id_2, room_id_3]) - # Different sort order of SQlite and PostreSQL - # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) - _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) - _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) - _order_test( - "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True - ) - - _order_test("state_events", [room_id_3, room_id_2, room_id_1]) - _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - - def test_search_term(self): - """Test that searching for a room works correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - def _search_test( - expected_room_id: Optional[str], - search_term: str, - expected_http_code: int = 200, - ): - """Search for a room and check that the returned room's id is a match - - Args: - expected_room_id: The room_id expected to be returned by the API. 
Set - to None to expect zero results for the search - search_term: The term to search for room names with - expected_http_code: The expected http code for the request - """ - url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - - if expected_http_code != 200: - return - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that the expected number of rooms were returned - expected_room_count = 1 if expected_room_id else 0 - self.assertEqual(len(rooms), expected_room_count) - self.assertEqual(channel.json_body["total_rooms"], expected_room_count) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - if expected_room_id: - # Check that the first returned room id is correct - r = rooms[0] - self.assertEqual(expected_room_id, r["room_id"]) - - # Perform search tests - _search_test(room_id_1, "something") - _search_test(room_id_1, "thing") - - _search_test(room_id_2, "else") - _search_test(room_id_2, "se") - - _search_test(None, "foo") - _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) - - def test_single_room(self): - """Test that a single room can be requested correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertIn("room_id", channel.json_body) - self.assertIn("name", channel.json_body) - self.assertIn("canonical_alias", channel.json_body) - self.assertIn("joined_members", channel.json_body) - self.assertIn("joined_local_members", channel.json_body) - self.assertIn("version", channel.json_body) - self.assertIn("creator", channel.json_body) - self.assertIn("encryption", channel.json_body) - self.assertIn("federatable", channel.json_body) - self.assertIn("public", channel.json_body) - self.assertIn("join_rules", channel.json_body) - self.assertIn("guest_access", channel.json_body) - self.assertIn("history_visibility", channel.json_body) - self.assertIn("state_events", channel.json_body) - - self.assertEqual(room_id_1, channel.json_body["room_id"]) - - def test_room_members(self): - """Test that room members can be requested correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Have another user join the room - user_1 = self.register_user("foo", "pass") - 
user_tok_1 = self.login("foo", "pass") - self.helper.join(room_id_1, user_1, tok=user_tok_1) - - # Have another user join the room - user_2 = self.register_user("bar", "pass") - user_tok_2 = self.login("bar", "pass") - self.helper.join(room_id_1, user_2, tok=user_tok_2) - self.helper.join(room_id_2, user_2, tok=user_tok_2) - - # Have another user join the room - user_3 = self.register_user("foobar", "pass") - user_tok_3 = self.login("foobar", "pass") - self.helper.join(room_id_2, user_3, tok=user_tok_3) - - url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertCountEqual( - ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] - ) - self.assertEqual(channel.json_body["total"], 3) - - url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertCountEqual( - ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] - ) - self.assertEqual(channel.json_body["total"], 3) - - -class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.creator = self.register_user("creator", "test") - self.creator_tok = self.login("creator", "test") - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - - self.public_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=True - ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) - - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error 403 is returned. 
- """ - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.second_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_invalid_parameter(self): - """ - If a parameter is missing, return an error - """ - body = json.dumps({"unknown_parameter": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - - def test_local_user_does_not_exist(self): - """ - Tests that a lookup for a user that does not exist returns a 404 - """ - body = json.dumps({"user_id": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_remote_user(self): - """ - Check that only local user can join rooms. - """ - body = json.dumps({"user_id": "@not:exist.bla"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "This endpoint can only be used with local users", - channel.json_body["error"], - ) - - def test_room_does_not_exist(self): - """ - Check that unknown rooms/server return error 404. - """ - body = json.dumps({"user_id": self.second_user_id}) - url = "/_synapse/admin/v1/join/!unknown:test" - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No known servers", channel.json_body["error"]) - - def test_room_is_not_valid(self): - """ - Check that invalid room names, return an error 400. 
- """ - body = json.dumps({"user_id": self.second_user_id}) - url = "/_synapse/admin/v1/join/invalidroom" - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "invalidroom was not legal room ID or room alias", - channel.json_body["error"], - ) - - def test_join_public_room(self): - """ - Test joining a local user to a public room with "JoinRules.PUBLIC" - """ - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.public_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - - def test_join_private_room_if_not_member(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE" - when server admin is not member of this room. - """ - private_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False - ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_join_private_room_if_member(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE", - when server admin is member of this room. - """ - private_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False - ) - self.helper.invite( - room=private_room_id, - src=self.creator, - targ=self.admin_user, - tok=self.creator_tok, - ) - self.helper.join( - room=private_room_id, user=self.admin_user, tok=self.admin_user_tok - ) - - # Validate if server admin is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - - # Join user to room. 
- - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - - def test_join_private_room_if_owner(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE", - when server admin is owner of this room. - """ - private_room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=False - ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import urllib.parse +from typing import List, Optional + +from mock import Mock + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import directory, events, login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, body="foo", tok=self.other_user_token, expect_code=403 + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. 
+ """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class DeleteRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "POST", self.url, json.dumps({}), access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. 
+        """
+        url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_room_is_not_valid(self):
+        """
+        Check that invalid room names return an error 400.
+        """
+        url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "invalidroom is not a legal room ID", channel.json_body["error"],
+        )
+
+    def test_new_room_user_does_not_exist(self):
+        """
+        Tests that the user ID must be from the local server but does not have to exist.
+        """
+        body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("kicked_users", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+    def test_new_room_user_is_not_local(self):
+        """
+        Check that only local users can create a new room to move members to.
+        """
+        body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "User must be our own: @not:exist.bla", channel.json_body["error"],
+        )
+
+    def test_block_is_not_bool(self):
+        """
+        If the parameter `block` is not a boolean, return an error.
+        """
+        body = json.dumps({"block": "NotBool"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_room_and_block(self):
+        """Test to purge a room and block it.
+        Members will not be moved to a new room and will not receive a message. 
+ """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": True}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. 
+ """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id, expect=True): + """Assert that the room is blocked or not + """ + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + 
self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id): + """Assert there is now no longer anyone in the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id, user_id): + """Test that user is member of the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id): + """Test that the following tables have been purged of all rows related to the room. + """ + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. 
+ for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + self.assertIn("joined_local_members", r) + self.assertIn("version", r) + self.assertIn("creator", r) + self.assertIn("encryption", r) + self.assertIn("federatable", r) + self.assertIn("public", r) + self.assertIn("join_rules", r) + self.assertIn("guest_access", r) + self.assertIn("history_visibility", r) + self.assertIn("state_events", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def 
test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + limit, + "name", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", 
{"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical name versus number of members, + reversing the order, etc. + """ + + def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % ( + urllib.parse.quote(test_alias), + ) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=admin_user_tok, + ) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. 
Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room canonical room aliases + _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) + + # Set room member size in the reverse order. 
room 1 -> 1 member, 2 -> 2, 3 -> 3
+        user_1 = self.register_user("bob1", "pass")
+        user_1_tok = self.login("bob1", "pass")
+        self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+        user_2 = self.register_user("bob2", "pass")
+        user_2_tok = self.login("bob2", "pass")
+        self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+        user_3 = self.register_user("bob3", "pass")
+        user_3_tok = self.login("bob3", "pass")
+        self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+        # Test different sort orders, with forward and reverse directions
+        _order_test("name", [room_id_1, room_id_2, room_id_3])
+        _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+        _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+        _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+        _order_test(
+            "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+        )
+
+        _order_test("version", [room_id_1, room_id_2, room_id_3])
+        _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("creator", [room_id_1, room_id_2, room_id_3])
+        _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+        _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+        _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("public", [room_id_1, room_id_2, room_id_3])
+        # Different sort order of SQLite and PostgreSQL
+        # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+        _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+        _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+        _order_test(
+            "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+        )
+
+        _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+        _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+    def test_search_term(self):
+        """Test that searching for a room works correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        room_name_1 = "something"
+        room_name_2 = "else"
+
+        # Set the name for each room
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+        )
+
+        def _search_test(
+            expected_room_id: Optional[str],
+            search_term: str,
+            expected_http_code: int = 200,
+        ):
+            """Search for a room and check that the returned room's id is a match
+
+            Args:
+                expected_room_id: The room_id expected to be returned by the API. 
Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) + + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, channel.json_body["room_id"]) + + def test_room_members(self): + """Test that room members can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + 
user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + # Have another user join the room + user_2 = self.register_user("bar", "pass") + user_tok_2 = self.login("bar", "pass") + self.helper.join(room_id_1, user_2, tok=user_tok_2) + self.helper.join(room_id_2, user_2, tok=user_tok_2) + + # Have another user join the room + user_3 = self.register_user("foobar", "pass") + user_tok_3 = self.login("foobar", "pass") + self.helper.join(room_id_2, user_3, tok=user_tok_3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. 
+ """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. 
+ """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. 
+ + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) -- cgit 1.5.1 From 8144bc26a7432463b7e70f9c03198d4724952522 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:21:34 -0400 Subject: Convert push to async/await. (#7948) --- changelog.d/7948.misc | 1 + synapse/push/action_generator.py | 7 +-- synapse/push/bulk_push_rule_evaluator.py | 62 ++++++++----------- synapse/push/httppusher.py | 58 ++++++++---------- synapse/push/presentable_names.py | 15 ++--- synapse/push/push_tools.py | 22 +++---- synapse/push/pusherpool.py | 70 +++++++++------------- .../storage/data_stores/main/event_push_actions.py | 4 +- tests/replication/slave/storage/test_events.py | 6 +- tests/storage/test_event_push_actions.py | 6 +- 10 files changed, 106 insertions(+), 145 deletions(-) create mode 100644 changelog.d/7948.misc (limited to 'tests') diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc new file mode 100644 index 0000000000..7c2e2b18b7 --- /dev/null +++ b/changelog.d/7948.misc @@ -0,0 +1 @@ +Convert push to async/await. diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 1ffd5e2df3..0d23142653 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.metrics import Measure from .bulk_push_rule_evaluator import BulkPushRuleEvaluator @@ -37,7 +35,6 @@ class ActionGenerator(object): # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). 
- @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user(event, context) + await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 472ddf9f7d..04b9d8ac82 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,6 @@ from collections import namedtuple from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY @@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object): resizable=False, ) - @defer.inlineCallbacks - def _get_rules_for_event(self, event, context): + async def _get_rules_for_event(self, event, context): """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = yield self._get_rules_for_room(room_id) + rules_for_room = await self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(event, context) + rules_by_user = await rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = yield self.store.user_has_pusher(invited) + has_pusher = await self.store.user_has_pusher(invited) if has_pusher: rules_by_user = dict(rules_by_user) - rules_by_user[invited] = yield self.store.get_push_rules_for_user( + rules_by_user[invited] = await self.store.get_push_rules_for_user( invited ) @@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics, ) - @defer.inlineCallbacks - def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids() + async def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = yield self.store.get_event(pl_event_id) + pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = yield self.auth.compute_auth_events( + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level - @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): + async def action_for_event_by_user(self, event, context) -> None: """Given an 
event and context, evaluate the push rules and insert the results into the event_push_actions_staging table. - - Returns: - Deferred """ - rules_by_user = yield self._get_rules_for_event(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context(event, context) + room_members = await self.store.get_joined_users_from_context(event, context) ( power_levels, sender_power_level, - ) = yield self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object): continue if not event.is_state(): - is_ignored = yield self.store.is_ignored_by(event.sender, uid) + is_ignored = await self.store.is_ignored_by(event.sender, uid) if is_ignored: continue @@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -274,8 +266,7 @@ class RulesForRoom(object): # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - @defer.inlineCallbacks - def get_rules(self, event, context): + async def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -286,7 +277,7 @@ class RulesForRoom(object): self.room_push_rule_cache_metrics.inc_hits() return self.rules_by_user - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() @@ -304,9 +295,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield defer.ensureDeferred( - context.get_current_state_ids() - ) + current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -353,7 +342,7 @@ class RulesForRoom(object): # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. 
logger.debug("Found new member events %r", missing_member_event_ids) - yield self._update_rules_with_member_event_ids( + await self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) else: @@ -371,8 +360,7 @@ class RulesForRoom(object): ) return ret_rules_by_user - @defer.inlineCallbacks - def _update_rules_with_member_event_ids( + async def _update_rules_with_member_event_ids( self, ret_rules_by_user, member_event_ids, state_group, event ): """Update the partially filled rules_by_user dict by fetching rules for @@ -388,7 +376,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) + rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} @@ -410,7 +398,7 @@ class RulesForRoom(object): logger.debug("Joined: %r", interested_in_user_ids) - if_users_with_pushers = yield self.store.get_if_users_have_pushers( + if_users_with_pushers = await self.store.get_if_users_have_pushers( interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) @@ -420,7 +408,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) - users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + users_with_receipts = await self.store.get_users_with_read_receipts_in_room( self.room_id, on_invalidate=self.invalidate_all_cb ) @@ -431,7 +419,7 @@ class RulesForRoom(object): if uid in interested_in_user_ids: user_ids.add(uid) - rules_by_user = yield self.store.bulk_get_push_rules( + rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2fac07593b..4c469efb20 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -17,7 +17,6 @@ import logging from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.api.constants import EventTypes @@ -128,12 +127,11 @@ class HttpPusher(object): # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - @defer.inlineCallbacks - def _update_badge(self): + async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. 
- badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - yield self._send_badge(badge) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + await self._send_badge(badge) def on_timer(self): self._start_processing() @@ -152,8 +150,7 @@ class HttpPusher(object): run_as_background_process("httppush.process", self._process) - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -164,7 +161,7 @@ class HttpPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -172,8 +169,7 @@ class HttpPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to @@ -181,7 +177,7 @@ class HttpPusher(object): """ fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = yield fn( + unprocessed = await fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -203,13 +199,13 @@ class HttpPusher(object): "app_display_name": self.app_display_name, }, ): - processed = yield self._process_one(push_action) + processed = await self._process_one(push_action) if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, self.user_id, @@ -224,14 +220,14 @@ class HttpPusher(object): if self.failing_since: self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) @@ -250,7 +246,7 @@ class HttpPusher(object): ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, @@ -263,7 +259,7 @@ class HttpPusher(object): return self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: @@ -276,18 +272,17 @@ class HttpPusher(object): ) break - @defer.inlineCallbacks - def _process_one(self, push_action): + async def _process_one(self, push_action): if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield 
self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: return True # It's been redacted - rejected = yield self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push(event, tweaks, badge) if rejected is False: return False @@ -301,11 +296,10 @@ class HttpPusher(object): ) else: logger.info("Pushkey %s was rejected: removing", pk) - yield self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict(self, event, tweaks, badge): priority = "low" if ( event.type == EventTypes.Encrypted @@ -335,7 +329,7 @@ class HttpPusher(object): } return d - ctx = yield push_tools.get_context_for_event( + ctx = await push_tools.get_context_for_event( self.storage, self.state_handler, event, self.user_id ) @@ -377,13 +371,12 @@ class HttpPusher(object): return d - @defer.inlineCallbacks - def dispatch_push(self, event, tweaks, badge): - notification_dict = yield self._build_notification_dict(event, tweaks, badge) + async def dispatch_push(self, event, tweaks, badge): + notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] try: - resp = yield self.http_client.post_json_get_json( + resp = await self.http_client.post_json_get_json( self.url, notification_dict ) except Exception as e: @@ -400,8 +393,7 @@ class HttpPusher(object): rejected = resp["rejected"] return rejected - @defer.inlineCallbacks - def _send_badge(self, badge): + async def _send_badge(self, badge): """ Args: badge (int): number of unread messages @@ -424,7 +416,7 @@ class HttpPusher(object): } } try: - yield self.http_client.post_json_get_json(self.url, d) + await self.http_client.post_json_get_json(self.url, d) http_badges_processed_counter.inc() except Exception as e: logger.warning( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0644a13cfc..d8f4a453cd 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -@defer.inlineCallbacks -def calculate_room_name( +async def calculate_room_name( store, room_state_ids, user_id, @@ -53,7 +50,7 @@ def calculate_room_name( """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: - m_room_name = yield store.get_event( + m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: @@ -61,7 +58,7 @@ def calculate_room_name( # does it have a canonical alias? 
if (EventTypes.CanonicalAlias, "") in room_state_ids: - canon_alias = yield store.get_event( + canon_alias = await store.get_event( room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( @@ -81,7 +78,7 @@ def calculate_room_name( my_member_event = None if (EventTypes.Member, user_id) in room_state_ids: - my_member_event = yield store.get_event( + my_member_event = await store.get_event( room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) @@ -90,7 +87,7 @@ def calculate_room_name( and my_member_event.content["membership"] == "invite" ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: - inviter_member_event = yield store.get_event( + inviter_member_event = await store.get_event( room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) @@ -107,7 +104,7 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: - member_events = yield store.get_events( + member_events = await store.get_events( list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 5dae4648c0..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage -@defer.inlineCallbacks -def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_local_user(user_id) - joins = yield store.get_rooms_for_user(user_id) +async def get_badge_count(store, user_id): + invites = await store.get_invited_rooms_for_local_user(user_id) + joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -32,7 +29,7 @@ def get_badge_count(store, user_id): if room_id in my_receipts_by_room: last_unread_event_id = my_receipts_by_room[room_id] - notifs = yield ( + notifs = await ( store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, last_unread_event_id ) @@ -43,23 +40,22 @@ def get_badge_count(store, user_id): return badge -@defer.inlineCallbacks -def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) + room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = yield calculate_room_name( + name = await calculate_room_name( storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield storage.main.get_event(sender_state_event_id) + sender_state_event = await storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 
2456f12f46..3c3262a88c 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -52,7 +50,7 @@ class PusherPool: Note that it is expected that each pusher will have its own 'processing' loop which will send out the notifications in the background, rather than blocking until the notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and - Pusher.on_new_receipts are not expected to return deferreds. + Pusher.on_new_receipts are not expected to return awaitables. """ def __init__(self, hs: "HomeServer"): @@ -77,8 +75,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -94,7 +91,7 @@ class PusherPool: """Creates a new pusher and adds it to the pool Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ time_now_msec = self.clock.time_msec() @@ -124,9 +121,9 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() + last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - yield self.store.add_pusher( + await self.store.add_pusher( user_id=user_id, access_token=access_token, kind=kind, @@ -140,15 +137,14 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) return pusher - @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user( + async def remove_pushers_by_app_id_and_pushkey_not_user( self, app_id, pushkey, not_user_id ): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: if p["user_name"] != not_user_id: logger.info( @@ -157,10 +153,9 @@ class PusherPool: pushkey, p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token(self, user_id, access_tokens): """Remove the pushers for a given user corresponding to a set of access_tokens. 
@@ -173,7 +168,7 @@ class PusherPool: return tokens = set(access_tokens) - for p in (yield self.store.get_pushers_by_user_id(user_id)): + for p in await self.store.get_pushers_by_user_id(user_id): if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", @@ -181,16 +176,15 @@ class PusherPool: p["pushkey"], p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. return try: - users_affected = yield self.store.get_push_action_users_in_range( + users_affected = await self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) @@ -202,8 +196,7 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): if not self.pushers: # nothing to do here. return @@ -211,7 +204,7 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - users_affected = yield self.store.get_users_sent_receipts_between( + users_affected = await self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) @@ -223,12 +216,11 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - @defer.inlineCallbacks - def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id(self, app_id, pushkey, user_id): """Look up the details for the given pusher, and start it Returns: - Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + EmailPusher|HttpPusher|None: The pusher started, if any """ if not self._should_start_pushers: return @@ -236,7 +228,7 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: @@ -245,34 +237,29 @@ class PusherPool: pusher = None if pusher_dict: - pusher = yield self._start_pusher(pusher_dict) + pusher = await self._start_pusher(pusher_dict) return pusher - @defer.inlineCallbacks - def _start_pushers(self): + async def _start_pushers(self) -> None: """Start all the pushers - - Returns: - Deferred """ - pushers = yield self.store.get_all_pushers() + pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. 
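# Illustration: `concurrently_execute(func, args, limit)`, used just below,
# runs a coroutine over a collection with at most `limit` calls in flight.
# The same idea expressed in plain asyncio, purely as a sketch -- this is not
# Synapse's implementation, which is Twisted-based.
import asyncio

async def bounded_run(func, items, limit):
    sem = asyncio.Semaphore(limit)

    async def _one(item):
        async with sem:
            await func(item)

    await asyncio.gather(*(_one(item) for item in items))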
- yield concurrently_execute(self._start_pusher, pushers, 10) + await concurrently_execute(self._start_pusher, pushers, 10) logger.info("Started pushers") - @defer.inlineCallbacks - def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusherdict): """Start the given pusher Args: pusherdict (dict): dict with the values pulled from the db table Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ if not self._pusher_shard_config.should_handle( self._instance_name, pusherdict["user_name"] @@ -315,7 +302,7 @@ class PusherPool: user_id = pusherdict["user_name"] last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: - have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering ) else: @@ -327,8 +314,7 @@ class PusherPool: return p - @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) @@ -340,6 +326,6 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 504babaa7e..18297cf3b8 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): _get_if_maybe_push_in_range_for_user_txn, ) - def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging(self, event_id, user_id_actions): """Add the push actions for the event to the push action staging area. 
Args: @@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.db.runInteraction( + return await self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1a88c7fb80..0b5204654c 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = self.get_success(state_handler.compute_event_context(event)) - self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, {user_id: actions for user_id, actions in push_actions} + ) ) return event, context diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b45bc9c115..43dbeb42c5 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -72,8 +72,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action} + yield defer.ensureDeferred( + self.store.add_push_actions_to_staging( + event.event_id, {user_id: action} + ) ) yield self.store.db.runInteraction( "", -- cgit 1.5.1 From 5f65e6268146a5ae7b8dafdfe2290b791e8b4c92 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:32:08 -0400 Subject: Convert groups and visibility code to async / await. (#7951) --- changelog.d/7951.misc | 1 + synapse/groups/attestations.py | 25 +++++++++++-------------- synapse/visibility.py | 30 +++++++++++++----------------- tests/test_visibility.py | 12 ++++++------ 4 files changed, 31 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7951.misc (limited to 'tests') diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc new file mode 100644 index 0000000000..cbba4fa826 --- /dev/null +++ b/changelog.d/7951.misc @@ -0,0 +1 @@ +Convert groups and visibility code to async / await. diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dab13c243f..e674bf44a2 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -41,8 +41,6 @@ from typing import Tuple from signedjson.sign import sign_json -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id @@ -72,8 +70,9 @@ class GroupAttestationSigning(object): self.server_name = hs.hostname self.signing_key = hs.signing_key - @defer.inlineCallbacks - def verify_attestation(self, attestation, group_id, user_id, server_name=None): + async def verify_attestation( + self, attestation, group_id, user_id, server_name=None + ): """Verifies that the given attestation matches the given parameters. 
An optional server_name can be supplied to explicitly set which server's @@ -102,7 +101,7 @@ class GroupAttestationSigning(object): if valid_until_ms < now: raise SynapseError(400, "Attestation expired") - yield self.keyring.verify_json_for_server( + await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) @@ -142,8 +141,7 @@ class GroupAttestionRenewer(object): self._start_renew_attestations, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def on_renew_attestation(self, group_id, user_id, content): + async def on_renew_attestation(self, group_id, user_id, content): """When a remote updates an attestation """ attestation = content["attestation"] @@ -151,11 +149,11 @@ class GroupAttestionRenewer(object): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): raise SynapseError(400, "Neither user not group are on this server") - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, user_id=user_id, group_id=group_id ) - yield self.store.update_remote_attestion(group_id, user_id, attestation) + await self.store.update_remote_attestion(group_id, user_id, attestation) return {} @@ -172,8 +170,7 @@ class GroupAttestionRenewer(object): now + UPDATE_ATTESTATION_TIME_MS ) - @defer.inlineCallbacks - def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]): group_id, user_id = group_user try: if not self.is_mine_id(group_id): @@ -186,16 +183,16 @@ class GroupAttestionRenewer(object): user_id, group_id, ) - yield self.store.remove_attestation_renewal(group_id, user_id) + await self.store.remove_attestation_renewal(group_id, user_id) return attestation = self.attestations.create_attestation(group_id, user_id) - yield self.transport_client.renew_group_attestation( + await self.transport_client.renew_group_attestation( destination, group_id, user_id, content={"attestation": attestation} ) - yield self.store.update_attestation_renewal( + await self.store.update_attestation_renewal( group_id, user_id, attestation ) except (RequestSendFailed, HttpResponseException) as e: diff --git a/synapse/visibility.py b/synapse/visibility.py index 0f042c5696..e3da7744d2 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,8 +16,6 @@ import logging import operator -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage @@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = ( ) -@defer.inlineCallbacks -def filter_events_for_client( +async def filter_events_for_client( storage: Storage, user_id, events, @@ -67,19 +64,19 @@ def filter_events_for_client( also be called to check whether a user can see the state at a given point. Returns: - Deferred[list[synapse.events.EventBase]] + list[synapse.events.EventBase] """ # Filter out events that have been soft failed so that we don't relay them # to clients. 
events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield storage.state.get_state_for_events( + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -90,7 +87,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -99,7 +96,7 @@ def filter_events_for_client( for room_id in room_ids: retention_policies[ room_id - ] = yield storage.main.get_retention_policy_for_room(room_id) + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event): """ @@ -254,8 +251,7 @@ def filter_events_for_client( return list(filtered_events) -@defer.inlineCallbacks -def filter_events_for_server( +async def filter_events_for_server( storage: Storage, server_name, events, @@ -277,7 +273,7 @@ def filter_events_for_server( backfill or not. Returns - Deferred[list[FrozenEvent]] + list[FrozenEvent] """ def is_sender_erased(event, erased_senders): @@ -321,7 +317,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -339,14 +335,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield storage.main.get_events(visibility_ids) + event_map = await storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in event_map.values() ) if not check_history_visibility_only: - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -375,7 +371,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. 
- event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -405,7 +401,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield storage.main.get_events( + event_map = await storage.main.get_events( [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] ) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index b371efc0df..a7a36174ea 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) for i in range(0, len(events_to_filter)): @@ -265,8 +265,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): storage.main = test_store storage.state = test_store - filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) ) logger.info("Filtering took %f seconds", time.time() - start) -- cgit 1.5.1 From 8553f4649857c7862e30917adc925642ad684a10 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 13:40:22 -0400 Subject: Convert a synapse.events to async/await. (#7949) --- changelog.d/7948.misc | 2 +- changelog.d/7949.misc | 1 + changelog.d/7951.misc | 2 +- synapse/api/auth.py | 2 +- synapse/events/builder.py | 19 +++++------- synapse/events/snapshot.py | 46 ++++++++++++++-------------- synapse/events/third_party_rules.py | 55 ++++++++++++++++++---------------- synapse/events/utils.py | 15 +++++----- synapse/handlers/federation.py | 2 +- synapse/replication/http/federation.py | 4 ++- synapse/replication/http/send_event.py | 2 +- tests/storage/test_redaction.py | 4 ++- tests/test_state.py | 14 ++++----- 13 files changed, 86 insertions(+), 82 deletions(-) create mode 100644 changelog.d/7949.misc (limited to 'tests') diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc index 7c2e2b18b7..dfe4c03171 100644 --- a/changelog.d/7948.misc +++ b/changelog.d/7948.misc @@ -1 +1 @@ -Convert push to async/await. +Convert various parts of the codebase to async/await. diff --git a/changelog.d/7949.misc b/changelog.d/7949.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7949.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
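# Illustration: once a helper becomes a native coroutine, remaining
# @defer.inlineCallbacks code (such as the updated tests above) reaches it by
# wrapping the coroutine in a Deferred. A minimal sketch; `filter_events` is a
# stand-in for any of the converted functions, not a real import.
from twisted.internet import defer

async def filter_events(events):
    return [e for e in events if not e.get("soft_failed")]

@defer.inlineCallbacks
def legacy_caller(events):
    filtered = yield defer.ensureDeferred(filter_events(events))
    return filtered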
diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc index cbba4fa826..dfe4c03171 100644 --- a/changelog.d/7951.misc +++ b/changelog.d/7951.misc @@ -1 +1 @@ -Convert groups and visibility code to async / await. +Convert various parts of the codebase to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b53e8451e5..2178e623da 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -82,7 +82,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 0bb216419a..69b53ca2bc 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,8 +17,6 @@ from typing import Optional import attr from nacl.signing import SigningKey -from twisted.internet import defer - from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -95,31 +93,30 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - @defer.inlineCallbacks - def build(self, prev_event_ids): + async def build(self, prev_event_ids): """Transform into a fully signed and hashed event Args: prev_event_ids (list[str]): The event IDs to use as the prev events Returns: - Deferred[FrozenEvent] + FrozenEvent """ - state_ids = yield defer.ensureDeferred( - self._state.get_current_state_ids(self.room_id, prev_event_ids) + state_ids = await self._state.get_current_state_ids( + self.room_id, prev_event_ids ) - auth_ids = yield self._auth.compute_auth_events(self, state_ids) + auth_ids = await self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = yield self._store.add_event_hashes(auth_ids) - prev_events = yield self._store.add_event_hashes(prev_event_ids) + auth_events = await self._store.add_event_hashes(auth_ids) + prev_events = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of(prev_event_ids) + old_depth = await self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f94cdcbaba..cca93e3a46 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import attr from frozendict import frozendict -from twisted.internet import defer - from synapse.appservice import ApplicationService +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import StateMap +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + @attr.s(slots=True) class EventContext: @@ -129,8 +131,7 @@ class EventContext: delta_ids=delta_ids, ) - @defer.inlineCallbacks - def serialize(self, event, store): + async def serialize(self, event: EventBase, store: "DataStore") -> dict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -146,7 +147,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -214,8 +215,7 @@ class EventContext: return self._state_group - @defer.inlineCallbacks - def get_current_state_ids(self): + async def get_current_state_ids(self) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -224,32 +224,31 @@ class EventContext: ``rejected`` is set. Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + Returns None if state_group is None, which happens when the associated + event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched() + await self._ensure_fetched() return self._current_state_ids - @defer.inlineCallbacks - def get_prev_state_ids(self): + async def get_prev_state_ids(self): """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group + dict[(str, str), str]|None: Returns None if state_group is None, which happens when the associated event is an outlier. Maps a (type, state_key) to the event ID of the state event matching this tuple. 
""" - yield self._ensure_fetched() + await self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -269,8 +268,8 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self): - return defer.succeed(None) + async def _ensure_fetched(self): + return None @attr.s(slots=True) @@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext): _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self): + async def _ensure_fetched(self): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return make_deferred_yieldable(self._fetching_state_deferred) + return await make_deferred_yieldable(self._fetching_state_deferred) - @defer.inlineCallbacks - def _fill_out_state(self): + async def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self._current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) if self._event_state_key is not None: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 459132d388..2956a64234 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Requester class ThirdPartyEventRules(object): @@ -39,76 +41,79 @@ class ThirdPartyEventRules(object): config=config, http_client=hs.get_simple_http_client() ) - @defer.inlineCallbacks - def check_event_allowed(self, event, context): + async def check_event_allowed( + self, event: EventBase, context: EventContext + ) -> bool: """Check if a provided event should be allowed in the given context. Args: - event (synapse.events.EventBase): The event to be checked. - context (synapse.events.snapshot.EventContext): The context of the event. + event: The event to be checked. + context: The context of the event. Returns: - defer.Deferred[bool]: True if the event should be allowed, False if not. + True if the event should be allowed, False if not. """ if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} for key, event_id in prev_state_ids.items(): - state_events[key] = yield self.store.get_event(event_id, allow_none=True) + state_events[key] = await self.store.get_event(event_id, allow_none=True) - ret = yield self.third_party_rules.check_event_allowed(event, state_events) + ret = await self.third_party_rules.check_event_allowed(event, state_events) return ret - @defer.inlineCallbacks - def on_create_room(self, requester, config, is_requester_admin): + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ) -> bool: """Intercept requests to create room to allow, deny or update the request config. Args: - requester (Requester) - config (dict): The creation config from the client. - is_requester_admin (bool): If the requester is an admin + requester + config: The creation config from the client. 
+ is_requester_admin: If the requester is an admin Returns: - defer.Deferred[bool]: Whether room creation is allowed or denied. + Whether room creation is allowed or denied. """ if self.third_party_rules is None: return True - ret = yield self.third_party_rules.on_create_room( + ret = await self.third_party_rules.on_create_room( requester, config, is_requester_admin ) return ret - @defer.inlineCallbacks - def check_threepid_can_be_invited(self, medium, address, room_id): + async def check_threepid_can_be_invited( + self, medium: str, address: str, room_id: str + ) -> bool: """Check if a provided 3PID can be invited in the given room. Args: - medium (str): The 3PID's medium. - address (str): The 3PID's address. - room_id (str): The room we want to invite the threepid to. + medium: The 3PID's medium. + address: The 3PID's address. + room_id: The room we want to invite the threepid to. Returns: - defer.Deferred[bool], True if the 3PID can be invited, False if not. + True if the 3PID can be invited, False if not. """ if self.third_party_rules is None: return True - state_ids = yield self.store.get_filtered_current_state_ids(room_id) - room_state_events = yield self.store.get_events(state_ids.values()) + state_ids = await self.store.get_filtered_current_state_ids(room_id) + room_state_events = await self.store.get_events(state_ids.values()) state_events = {} for key, event_id in state_ids.items(): state_events[key] = room_state_events[event_id] - ret = yield self.third_party_rules.check_threepid_can_be_invited( + ret = await self.third_party_rules.check_threepid_can_be_invited( medium, address, state_events ) return ret diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 11f0d34ec8..2d42e268c6 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -18,8 +18,6 @@ from typing import Any, Mapping, Union from frozendict import frozendict -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion @@ -337,8 +335,9 @@ class EventClientSerializer(object): hs.config.experimental_msc1849_support_enabled ) - @defer.inlineCallbacks - def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): + async def serialize_event( + self, event, time_now, bundle_aggregations=True, **kwargs + ): """Serializes a single event. 
Args: @@ -348,7 +347,7 @@ class EventClientSerializer(object): **kwargs: Arguments to pass to `serialize_event` Returns: - Deferred[dict]: The serialized event + dict: The serialized event """ # To handle the case of presence events and the like if not isinstance(event, EventBase): @@ -363,8 +362,8 @@ class EventClientSerializer(object): if not event.internal_metadata.is_redacted() and ( self.experimental_msc1849_support_enabled and bundle_aggregations ): - annotations = yield self.store.get_aggregation_groups_for_event(event_id) - references = yield self.store.get_relations_for_event( + annotations = await self.store.get_aggregation_groups_for_event(event_id) + references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) @@ -378,7 +377,7 @@ class EventClientSerializer(object): edit = None if event.type == EventTypes.Message: - edit = yield self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id) if edit: # If there is an edit replace the content, preserving existing diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f5f683bfd4..0d7d1adcea 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler): } current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) + current_state_ids = dict(current_state_ids) # type: ignore current_state_ids.update(state_updates) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index c287c4e269..ca065e819e 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """ event_payloads = [] for event, context in event_and_contexts: - serialized_context = yield context.serialize(event, store) + serialized_context = yield defer.ensureDeferred( + context.serialize(event, store) + ) event_payloads.append( { diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index c981723c1a..b30e4d5039 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = yield context.serialize(event, store) + serialized_context = yield defer.ensureDeferred(context.serialize(event, store)) payload = { "event": event.get_pdu_json(), diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index db3667dc43..0f0e1cd09b 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): - built_event = yield self._base_builder.build(prev_event_ids) + built_event = yield defer.ensureDeferred( + self._base_builder.build(prev_event_ids) + ) built_event._event_id = self._event_id built_event._dict["event_id"] = self._event_id diff --git a/tests/test_state.py b/tests/test_state.py index 4858e8fc59..b5c3667d2a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield 
defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids()) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) @@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase): self.state.compute_event_context(event, old_state=old_state) ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) @@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase): self.state.compute_event_context(event, old_state=old_state) ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) @@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) -- cgit 1.5.1 From 3857de2194e3b2057c4af71e095eb6759508f25f Mon Sep 17 00:00:00 2001 From: lugino-emeritus Date: Tue, 28 Jul 2020 14:41:44 +0200 Subject: Option to allow server admins to join complex rooms (#7902) Fixes #7901. Signed-off-by: Niklas Tittjung --- changelog.d/7902.feature | 1 + docs/sample_config.yaml | 4 ++ synapse/config/server.py | 7 +++ synapse/handlers/room_member.py | 8 ++- tests/federation/test_complexity.py | 109 ++++++++++++++++++++++++++++++++++++ 5 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7902.feature (limited to 'tests') diff --git a/changelog.d/7902.feature b/changelog.d/7902.feature new file mode 100644 index 0000000000..4feae8cc29 --- /dev/null +++ b/changelog.d/7902.feature @@ -0,0 +1 @@ +Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus. 
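Taken together, the changes below amount to a single gating rule: the remote-join complexity check only runs when `limit_remote_rooms.enabled` is set and the joining user is not exempted by the new `admins_can_join` option. A minimal sketch of that decision, for illustration only (`should_check_complexity`, `limits` and `user_is_admin` are stand-in names; `limits` represents the parsed `limit_remote_rooms` config block and `user_is_admin` the result of the server-admin lookup done in the handler):

def should_check_complexity(limits, user_is_admin: bool) -> bool:
    # Sketch of the gating logic this commit adds to the remote-join path.
    if not limits.enabled:
        # Room complexity checking is switched off entirely.
        return False
    if limits.admins_can_join and user_is_admin:
        # New behaviour: server admins may join rooms that fail the check.
        return False
    return True

With `admins_can_join` left at its default of false, the predicate reduces to the previous `limit_remote_rooms.enabled` behaviour.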
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3227294e0b..09a7299871 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -314,6 +314,10 @@ limit_remote_rooms: # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 3747a01ca7..848587d232 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -439,6 +439,9 @@ class ServerConfig(Config): validator=attr.validators.instance_of(str), default=ROOM_COMPLEXITY_TOO_GREAT, ) + admins_can_join = attr.ib( + validator=attr.validators.instance_of(bool), default=False + ) self.limit_remote_rooms = LimitRemoteRoomsConfig( **(config.get("limit_remote_rooms") or {}) @@ -893,6 +896,10 @@ class ServerConfig(Config): # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a1a8fa1d3b..5a40e8c144 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -952,7 +952,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - if self.hs.config.limit_remote_rooms.enabled: + check_complexity = self.hs.config.limit_remote_rooms.enabled + if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = not await self.hs.auth.is_server_admin(user) + + if check_complexity: # Fetch the room complexity too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts @@ -975,7 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. - if self.hs.config.limit_remote_rooms.enabled: + if check_complexity: if too_complex is False: # We checked, and we're under the limit. return event_id, stream_id diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 0c9987be54..5cd0510f0d 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -99,6 +99,37 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_join_too_large_admin(self): + # Check whether an admin can join if option "admins_can_join" is undefined, + # this option defaults to false, so the join should fail. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. 
+ f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_join_too_large_once_joined(self): u1 = self.register_user("u1", "pass") @@ -141,3 +172,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): f = self.get_failure(d, SynapseError) self.assertEqual(f.value.code, 400) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + +class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): + # Test the behavior of joining rooms which exceed the complexity if option + # limit_remote_rooms.admins_can_join is True. + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["limit_remote_rooms"] = { + "enabled": True, + "complexity": 0.05, + "admins_can_join": True, + } + return config + + def test_join_too_large_no_admin(self): + # A user which is not an admin should not be able to join a remote room + # which is too complex. + + u1 = self.register_user("u1", "pass") + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # An admin should be able to join rooms where a complexity check fails. 
+ + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request success since the user is an admin + self.get_success(d) -- cgit 1.5.1 From e866e3b8966efc470038b48061a89aac513eb6e0 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 28 Jul 2020 21:08:23 +0200 Subject: Add an option to disable purge in delete room admin API (#7964) Add option ```purge``` to ```POST /_synapse/admin/v1/rooms//delete``` Fixes: #3761 Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/7964.feature | 1 + docs/admin_api/rooms.md | 13 +++++++--- synapse/rest/admin/rooms.py | 11 ++++++++- tests/rest/admin/test_room.py | 57 +++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 changelog.d/7964.feature (limited to 'tests') diff --git a/changelog.d/7964.feature b/changelog.d/7964.feature new file mode 100644 index 0000000000..ffe861650c --- /dev/null +++ b/changelog.d/7964.feature @@ -0,0 +1 @@ +Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms//delete`). Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 15b83e9824..0f267d2b7b 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -369,7 +369,9 @@ to the new room will have power level `-10` by default, and thus be unable to sp If `block` is `True` it prevents new joins to the old room. This API will remove all trace of the old room from your database after removing -all local users. +all local users. If `purge` is `true` (the default), all traces of the old room will +be removed from your database after removing all local users. If you do not want +this to happen, set `purge` to `false`. Depending on the amount of history being purged a call to the API may take several minutes or longer. @@ -388,7 +390,8 @@ with a body of: "new_room_user_id": "@someuser:example.com", "room_name": "Content Violation Notification", "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", - "block": true + "block": true, + "purge": true } ``` @@ -430,8 +433,10 @@ The following JSON body parameters are available: `new_room_user_id` in the new room. Ideally this will clearly convey why the original room was shut down. Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` -* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to - join the room. Defaults to `false`. +* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing + future attempts to join the room. Defaults to `false`. +* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. + Defaults to `true`. The JSON body must not be empty. The body must be at least `{}`. 
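For reference, a shutdown request that blocks the room but skips the purge step could look like the sketch below. The host, room ID and access token are placeholders, and the endpoint path is assumed to be the versioned `/_synapse/admin/v1/rooms/<room_id>/delete` route that the API summary above abbreviates; the response fields in the final comment are the ones exercised by the tests in this commit.

import requests

HOMESERVER = "https://homeserver.example.com"  # placeholder
ROOM_ID = "%21badroom%3Aexample.com"           # URL-encoded form of a hypothetical !badroom:example.com
ADMIN_TOKEN = "<admin_access_token>"           # placeholder

resp = requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/rooms/{ROOM_ID}/delete",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={
        "block": True,   # prevent future joins to the old room
        "purge": False,  # new option: keep the room's data in the database
    },
)
resp.raise_for_status()
# The response carries kicked_users, failed_to_kick_users, local_aliases and
# new_room_id (None here, since no replacement room was requested).
print(resp.json())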
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index b8c95d045a..a8364d9793 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet): ) # Purge room - await self.pagination_handler.purge_room(room_id) + if purge: + await self.pagination_handler.purge_room(room_id) return (200, ret) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ba8552c29f..cec1cf928f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -283,6 +283,23 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + body = json.dumps({"purge": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + def test_purge_room_and_block(self): """Test to purge a room and block it. Members will not be moved to a new room and will not receive a message. @@ -297,7 +314,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True}) + body = json.dumps({"block": True, "purge": True}) request, channel = self.make_request( "POST", @@ -331,7 +348,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": False}) + body = json.dumps({"block": False, "purge": True}) request, channel = self.make_request( "POST", @@ -351,6 +368,42 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=False) self._has_no_members(self.room_id) + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. 
+ """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to -- cgit 1.5.1 From 3345c166a45cb4a8f87c583ee0476c2bca5c41bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Jul 2020 16:09:53 -0400 Subject: Convert storage layer to async/await. (#7963) --- changelog.d/7963.misc | 1 + synapse/storage/persist_events.py | 40 ++++---- synapse/storage/purge_events.py | 38 ++++--- synapse/storage/state.py | 207 ++++++++++++++++++++------------------ tests/storage/test_purge.py | 8 +- tests/storage/test_room.py | 6 +- tests/storage/test_state.py | 64 +++++++----- tests/test_visibility.py | 14 ++- tests/utils.py | 16 +-- tox.ini | 1 + 10 files changed, 210 insertions(+), 185 deletions(-) create mode 100644 changelog.d/7963.misc (limited to 'tests') diff --git a/changelog.d/7963.misc b/changelog.d/7963.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7963.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 78fbdcdee8..4a164834d9 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import FrozenEvent +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process @@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def persist_events( + async def persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> int: """ Write events to the database Args: @@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): which might update the current state etc. 
Returns: - Deferred[int]: the stream ordering of the latest persisted event + the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): for room_id in partitioned: self._maybe_start_persisting(room_id) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) - max_persisted_id = yield self.main_store.get_current_events_token() - - return max_persisted_id + return self.main_store.get_current_events_token() - @defer.inlineCallbacks - def persist_event( - self, event: FrozenEvent, context: EventContext, backfilled: bool = False - ): + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[int, int]: """ Returns: - Deferred[Tuple[int, int]]: the stream ordering of ``event``, - and the stream ordering of the latest persisted event + The stream ordering of `event`, and the stream ordering of the + latest persisted event """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): self._maybe_start_persisting(event.room_id) - yield make_deferred_yieldable(deferred) + await make_deferred_yieldable(deferred) - max_persisted_id = yield self.main_store.get_current_events_token() + max_persisted_id = self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id: str): @@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): async def _persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, ): """Calculates the change to current state and forward extremities, and @@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): async def _calculate_new_extremities( self, room_id: str, - event_contexts: List[Tuple[FrozenEvent, EventContext]], + event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: List[str], ): """Calculates the new forward extremities for a room given events to @@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): async def _get_new_state_after_events( self, room_id: str, - events_context: List[Tuple[FrozenEvent, EventContext]], + events_context: List[Tuple[EventBase, EventContext]], old_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: @@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, current_state: Optional[StateMap[str]], potentially_left_users: Set[str], diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index fdc0abf5cf..79d9f06e2e 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -15,8 +15,7 @@ import itertools import logging - -from twisted.internet import defer +from typing import Set logger = logging.getLogger(__name__) @@ -28,49 +27,48 @@ class PurgeEventsStorage(object): def __init__(self, hs, stores): self.stores = stores - @defer.inlineCallbacks - def purge_room(self, room_id: str): + async def purge_room(self, room_id: str): """Deletes all record of a room """ - state_groups_to_delete 
= yield self.stores.main.purge_room(room_id) - yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - @defer.inlineCallbacks - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: """Deletes room history before a certain point Args: - room_id (str): + room_id: The room ID - token (str): A topological token to delete events before + token: A topological token to delete events before - delete_local_events (bool): + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). """ - state_groups = yield self.stores.main.purge_history( + state_groups = await self.stores.main.purge_history( room_id, token, delete_local_events ) logger.info("[purge] finding state groups that can be deleted") - sg_to_delete = yield self._find_unreferenced_groups(state_groups) + sg_to_delete = await self._find_unreferenced_groups(state_groups) - yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - @defer.inlineCallbacks - def _find_unreferenced_groups(self, state_groups): + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. Args: - state_groups (set[int]): Set of state groups referenced by events + state_groups: Set of state groups referenced by events that are going to be deleted. Returns: - Deferred[set[int]] The set of state groups that can be deleted. + The set of state groups that can be deleted. """ # Graph of state group -> previous group graph = {} @@ -93,7 +91,7 @@ class PurgeEventsStorage(object): current_search = set(itertools.islice(next_to_search, 100)) next_to_search -= current_search - referenced = yield self.stores.main.get_referenced_state_groups( + referenced = await self.stores.main.get_referenced_state_groups( current_search ) referenced_groups |= referenced @@ -102,7 +100,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.state.get_previous_state_groups(current_search) + edges = await self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dc568476f4..49ee9c9a74 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,13 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -34,16 +33,16 @@ class StateFilter(object): """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. 
- include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing @@ -52,36 +51,35 @@ class StateFilter(object): self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -91,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -130,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -167,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -179,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). 
""" where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -221,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -260,33 +257,33 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) @@ -295,7 +292,7 @@ class StateFilter(object): for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -307,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -340,6 +337,9 @@ class StateGroupStorage(object): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. 
+ Returns: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) @@ -347,55 +347,59 @@ class StateGroupStorage(object): return self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id for group_ids in group_to_ids.values() @@ -423,31 +427,34 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. 
+ Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) @@ -463,24 +470,24 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) @@ -491,36 +498,36 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. 
Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( @@ -530,9 +537,8 @@ class StateGroupStorage(object): filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. Returns: Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. @@ -540,18 +546,23 @@ class StateGroupStorage(object): return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, ): """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index b9fafaa1a6..a6012c973d 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from twisted.internet import defer + from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred( + storage.purge_events.purge_history(self.room_id, event, True) + ) self.pump() self.assertEqual(self.successResultOf(purge), None) @@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1d77b4a2d6..a5f250d477 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a0e133cd4a..6a48b9d3b3 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) 
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/test_visibility.py b/tests/test_visibility.py index a7a36174ea..531a9b9118 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event 
@defer.inlineCallbacks @@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index ac643679aa..b33b6860d4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -638,14 +638,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield defer.ensureDeferred( - event_creation_handler.create_new_client_event(builder) - ) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index 595ab3ba66..a394f6eadc 100644 --- a/tox.ini +++ b/tox.ini @@ -206,6 +206,7 @@ commands = mypy \ synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ + synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ -- cgit 1.5.1 From 8dff4a12424cda9e4abaa5f2905d58aa6e723777 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 29 Jul 2020 18:26:55 +0100 Subject: Re-implement unread counts (#7736) --- changelog.d/7736.feature | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 + synapse/push/push_tools.py | 17 +-- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/data_stores/main/cache.py | 1 + synapse/storage/data_stores/main/events.py | 48 ++++++- synapse/storage/data_stores/main/events_worker.py | 86 ++++++++++- .../main/schema/delta/58/12unread_messages.sql | 18 +++ tests/rest/client/v1/utils.py | 20 +++ tests/rest/client/v2_alpha/test_sync.py | 157 ++++++++++++++++++++- 11 files changed, 339 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7736.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql (limited to 'tests') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature new file mode 100644 index 0000000000..c97864677a --- /dev/null +++ b/changelog.d/7736.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses. 
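On the wire, the new counter shows up as an extra per-room field of the /sync response, under the unstable `org.matrix.msc2654.unread_count` name added to the sync servlet below. A small client-side sketch of reading it, assuming the standard `rooms`/`join` layout of a parsed sync body (the helper name is illustrative only):

def unread_counts_from_sync(sync_body: dict) -> dict:
    # Map each joined room ID to its unread message count, defaulting to 0
    # for rooms where the (unstable) field is absent.
    joined = sync_body.get("rooms", {}).get("join", {})
    return {
        room_id: room.get("org.matrix.msc2654.unread_count", 0)
        for room_id, room in joined.items()
    }

Server-side, the count itself comes from the new `get_unread_message_count_for_user` store method and the `should_count_as_unread` heuristic introduced in the events stores below.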
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22a6abd7d2..bee525197f 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], + "events": ["processed", "outlier", "contains_url", "count_as_unread"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ebd3e98105..eaa4eeadf7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,6 +103,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1886,6 +1887,10 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] + + unread_count = await self.store.get_unread_message_count_for_user( + room_id, sync_config.user.to_string(), + ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1894,6 +1899,7 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..bc8f71916b 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,22 +21,13 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") - badge = len(invites) for room_id in joins: - if room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[room_id] - - notifs = await ( - store.get_unread_event_push_actions_by_room_for_user( - room_id, user_id, last_unread_event_id - ) - ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + unread_count = await store.get_unread_message_count_for_user(room_id, user_id) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if unread_count else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..3f5bf75e59 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index f39f556c20..edc3624fed 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) + self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git 
a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6f2e0d15cc..0c9c02afa1 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -53,6 +53,47 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + def encode_json(json_object): """ @@ -196,6 +237,10 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() + self.store.get_unread_message_count_for_user.invalidate_many( + (event.room_id,), + ) + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -817,8 +862,9 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), + "count_as_unread": should_count_as_unread(event, context), } - for event, _ in events_and_contexts + for event, context in events_and_contexts ], ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e812c67078..b03b259636 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. 
+ + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. + txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql new file mode 100644 index 0000000000..531b532c73 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Store a boolean value in the events table for whether the event should be counted in +-- the unread_count property of sync responses. 
+ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 22d734e763..7f8252330a 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -143,6 +143,26 @@ class RestHelper(object): return channel.json_body + def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body + def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. 
+ self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: True): + """Syncs and compares the unread count with the expected value.""" + + request, channel = self.make_request( + "GET", self.url % self.next_batch, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1 From b3a97d6dac7f9f619b02e213bb8a745d65983d0d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 07:20:41 -0400 Subject: Convert some of the data store to async. 
(#7976) --- changelog.d/7976.misc | 1 + .../storage/data_stores/main/event_push_actions.py | 92 ++++++++++---------- synapse/storage/data_stores/main/room.py | 98 ++++++++++------------ synapse/storage/data_stores/main/state.py | 57 ++++++------- synapse/storage/data_stores/main/stats.py | 53 ++++++------ synapse/storage/data_stores/state/store.py | 37 ++++---- synapse/storage/state.py | 11 +-- tests/storage/test_event_push_actions.py | 12 ++- tests/storage/test_room.py | 24 +++--- tests/storage/test_state.py | 12 +-- 10 files changed, 190 insertions(+), 207 deletions(-) create mode 100644 changelog.d/7976.misc (limited to 'tests') diff --git a/changelog.d/7976.misc b/changelog.d/7976.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7976.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 18297cf3b8..ad82838901 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -15,11 +15,10 @@ # limitations under the License. import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database @@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): return {"notify_count": notify_count, "highlight_count": highlight_count} - @defer.inlineCallbacks - def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + async def get_push_action_users_in_range( + self, min_stream_ordering, max_stream_ordering + ): def f(txn): sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" @@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.db.runInteraction("get_push_action_users_in_range", f) + ret = await self.db.runInteraction("get_push_action_users_in_range", f) return ret - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_http( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_http( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". The list will be ordered by ascending stream_ordering. 
The list will have between 0~limit entries. """ @@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): # one of the subqueries may have hit the limit. return notifs[:limit] - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_email( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_email( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions", "received_ts". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". The list will be ordered by descending received_ts. The list will have between 0~limit entries. 
""" @@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) - @defer.inlineCallbacks - def remove_push_actions_from_staging(self, event_id): + async def remove_push_actions_from_staging(self, event_id: str) -> None: """Called if we failed to persist the event to ensure that stale push actions don't build up in the DB - - Args: - event_id (str) """ try: - res = yield self.db.simple_delete( + res = await self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): + async def get_time_of_last_push_action_before(self, stream_ordering): def f(txn): sql = ( "SELECT e.received_ts" @@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + result = await self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def get_push_actions_for_user( + async def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False ): def f(txn): @@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, args) return self.db.cursor_to_dict(txn) - push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) + push_actions = await self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - @defer.inlineCallbacks - def get_latest_push_action_stream_ordering(self): + async def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.db.runInteraction( + result = await self.db.runInteraction( "get_latest_push_action_stream_ordering", f ) return result[0] or 0 @@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def _start_rotate_notifs(self): return run_as_background_process("rotate_notifs", self._rotate_notifs) - @defer.inlineCallbacks - def _rotate_notifs(self): + async def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.db.runInteraction( + caught_up = await self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if 
caught_up: break - yield self.hs.get_clock().sleep(self._rotate_delay) + await self.hs.get_clock().sleep(self._rotate_delay) finally: self._doing_notif_rotation = False diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index d2e1e36e7f..ab48052cdc 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore): return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) - @defer.inlineCallbacks - def get_largest_public_rooms( + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], @@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore): return results - ret_val = yield self.db.runInteraction( + ret_val = await self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - defer.returnValue(ret_val) + return ret_val @cached(max_entries=10000) def is_room_blocked(self, room_id): @@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore): "get_rooms_paginate", _get_rooms_paginate_txn, ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): + @cached(max_entries=10000) + async def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given user @@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.db.simple_select_one( + row = await self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore): else: return None - @cachedInlineCallbacks() - def get_retention_policy_for_room(self, room_id): + @cached() + async def get_retention_policy_for_room(self, room_id): """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore): return self.db.cursor_to_dict(txn) - ret = yield self.db.runInteraction( + ret = await self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. 
if not ret: - defer.returnValue( - { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - } - ) + return { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } row = ret[0] @@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore): if row["max_lifetime"] is None: row["max_lifetime"] = self.config.retention_default_max_lifetime - defer.returnValue(row) + return row def get_media_mxcs_in_room(self, room_id): """Retrieves all the local and remote media MXC URIs in a given room @@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_add_rooms_room_version_column, ) - @defer.inlineCallbacks - def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore): else: return False - end = yield self.db.runInteraction( + end = await self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - yield self.db.updates._end_background_update("insert_room_retention") + await self.db.updates._end_background_update("insert_room_retention") - defer.returnValue(batch_size) + return batch_size async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int @@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def store_room( + async def store_room( self, room_id: str, room_creator_user_id: str, @@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) + await self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): + async def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): self.db.simple_update_one_txn( txn, @@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() - @defer.inlineCallbacks - def set_room_is_public_appservice( + async def set_room_is_public_appservice( self, room_id, appservice_id, network_id, is_public ): """Edit the appservice/network specific public room list. 
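The room.py hunks above apply the standard inlineCallbacks-to-async conversion: yield becomes await, defer.returnValue(x) becomes a plain return x, and @cachedInlineCallbacks becomes @cached on an async def. A reduced before/after sketch of that shape, using plain Twisted and hypothetical names (the Synapse-internal @cached decorator is deliberately left out):

    from twisted.internet import defer, task


    @defer.inlineCallbacks
    def get_policy_old(db):
        # Old style: a generator-based coroutine; results arrive via yield and
        # are handed back with defer.returnValue().
        row = yield db.fetch_row()
        defer.returnValue(row or {"min_lifetime": None, "max_lifetime": None})


    async def get_policy_new(db):
        # New style: a native coroutine; await replaces yield and a plain
        # return replaces defer.returnValue().
        row = await db.fetch_row()
        return row or {"min_lifetime": None, "max_lifetime": None}


    class FakeDb:
        def fetch_row(self):
            # Stand-in for db.runInteraction(); returns an already-fired Deferred.
            return defer.succeed(None)


    async def check():
        db = FakeDb()
        # Calling an inlineCallbacks function still returns a Deferred, which a
        # native coroutine can await, so both styles interoperate.
        assert (await get_policy_old(db)) == (await get_policy_new(db))


    def main(reactor):
        return defer.ensureDeferred(check())


    if __name__ == "__main__":
        task.react(main)
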
@@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @defer.inlineCallbacks - def block_room(self, room_id, user_id): + async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred + room_id: Room to block + user_id: Who blocked it """ - yield self.db.simple_upsert( + await self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.db.runInteraction( + await self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, (room_id,), ) - @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range( - self, min_ms, max_ms, include_null=False - ): + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, doesn't set an upper limit. - include_null (bool): Whether to include rooms which retention policy is NULL + include_null: Whether to include rooms which retention policy is NULL in the returned set. Returns: - dict[str, dict]: The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). 
""" def get_rooms_for_retention_period_in_range_txn(txn): @@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.db.runInteraction( + rooms = await self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) - defer.returnValue(rooms) + return rooms diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index bb38a04ede..a360699408 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -16,12 +16,12 @@ import collections.abc import logging from collections import namedtuple - -from twisted.internet import defer +from typing import Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore @@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_create_event_for_room(room_id) return create_event.content.get("room_version", "1") - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): + async def get_room_predecessor(self, room_id: str) -> Optional[dict]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[dict|None]: A dictionary containing the structure of the predecessor - field from the room's create event. The structure is subject to other servers, - but it is expected to be: - * room_id (str): The room ID of the predecessor room - * event_id (str): The ID of the tombstone event in the predecessor room + A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room - None if a predecessor key is not found, or is not a dictionary. + None if a predecessor key is not found, or is not a dictionary. Raises: NotFoundError if the given room is unknown """ # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) + create_event = await self.get_create_event_for_room(room_id) # Retrieve the predecessor key of the create event predecessor = create_event.content.get("predecessor", None) @@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return predecessor - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): + async def get_create_event_for_room(self, room_id: str) -> EventBase: """Get the create state event for a room. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[EventBase]: The room creation event. + The room creation event. 
Raises: NotFoundError if the room is unknown """ - state_ids = yield self.get_current_state_ids(room_id) + state_ids = await self.get_current_state_ids(room_id) create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end @@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) + create_event = await self.get_event(create_id) return create_event @cached(max_entries=100000, iterable=True) @@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: """Get canonical alias for room, if any Args: - room_id (str) + room_id: The room ID Returns: - Deferred[str|None]: The canonical alias, if any + The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids( + state = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) ) @@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_id: return - event = yield self.get_event(event_id, allow_none=True) + event = await self.get_event(event_id, allow_none=True) if not event: return @@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {row["event_id"]: row["state_group"] for row in rows} - @defer.inlineCallbacks - def get_referenced_state_groups(self, state_groups): + async def get_referenced_state_groups( + self, state_groups: Iterable[int] + ) -> Set[int]: """Check if the state groups are referenced by events. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[set[int]]: The subset of state groups that are - referenced. + The subset of state groups that are referenced. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 380c1ec7da..922400a7c3 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -16,8 +16,8 @@ import logging from itertools import chain +from typing import Tuple -from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore): """ return (ts // self.stats_bucket_size) * self.stats_bucket_size - @defer.inlineCallbacks - def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. 
""" if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: - yield self._calculate_and_set_initial_state_for_user(user_id) + await self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.db.runInteraction( + await self.db.runInteraction( "populate_stats_process_users", self.db.updates._background_update_progress_txn, "populate_stats_process_users", @@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) - @defer.inlineCallbacks - def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms(self, progress, batch_size): """ This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.db.runInteraction( + rooms_to_work_on = await self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: - yield self._calculate_and_set_initial_state_for_room(room_id) + await self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.db.runInteraction( + await self.db.runInteraction( "_populate_stats_process_rooms", self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", @@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore): return room_deltas, user_deltas - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_room(self, room_id): + async def _calculate_and_set_initial_state_for_room( + self, room_id: str + ) -> Tuple[dict, dict, int]: """Calculate and insert an entry into room_stats_current. Args: - room_id (str) + room_id: The room ID under calculation. Returns: - Deferred[tuple[dict, dict, int]]: A tuple of room state, membership - counts and stream position. + A tuple of room state, membership counts and stream position. 
""" def _fetch_current_state_stats(txn): @@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.db.runInteraction( + ) = await self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = yield self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) room_state = { "join_rules": None, @@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore): event.content.get("m.federate", True) is True ) - yield self.update_room_state(room_id, room_state) + await self.update_room_state(room_id, room_state) local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="room", stats_id=room_id, @@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore): }, ) - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_user(self, user_id): + async def _calculate_and_set_initial_state_for_user(self, user_id): def _calculate_and_set_initial_state_for_user_txn(txn): pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) @@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.db.runInteraction( + joined_rooms, pos = await self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="user", stats_id=user_id, diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 128c09a2cf..7dada7f75f 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) - @defer.inlineCallbacks - def _get_state_groups_from_groups( + async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ results = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.db.runInteraction( + res = await self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - @defer.inlineCallbacks - def _get_state_for_groups( + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[int, StateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. 
Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ( non_member_state, incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( + ) = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Help the cache hit ratio by expanding the filter a bit db_state_filter = state_filter.return_expanded() - group_to_state_dict = yield self._get_state_groups_from_groups( + group_to_state_dict = await self._get_state_groups_from_groups( list(incomplete_groups), state_filter=db_state_filter ) @@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ((sg,) for sg in state_groups_to_delete), ) - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: """Fetch the previous groups of the given state groups. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. + A mapping from state group to previous state group. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 49ee9c9a74..534883361f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr @@ -419,7 +419,7 @@ class StateGroupStorage(object): def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -429,7 +429,7 @@ class StateGroupStorage(object): from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -532,7 +532,7 @@ class StateGroupStorage(object): def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -540,8 +540,9 @@ class StateGroupStorage(object): groups: list of state groups for which we want to get the state. state_filter: The state filter used to fetch state. from the database. + Returns: - Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. 
""" return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 43dbeb42c5..2b1580feeb 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index a5f250d477..d07b985a8e 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -88,11 +90,13 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6a48b9d3b3..8bd12fa847 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks -- cgit 1.5.1 From 4cce8ef74ec233d8e49361bee705f2e38de2e11e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 07:27:39 -0400 Subject: Convert appservice to async. 
(#7973) --- changelog.d/7973.misc | 1 + synapse/appservice/__init__.py | 31 ++++++------- synapse/appservice/api.py | 21 ++++----- synapse/appservice/scheduler.py | 49 +++++++++----------- synapse/handlers/appservice.py | 10 ++--- tests/appservice/test_appservice.py | 89 +++++++++++++++++++++++++------------ tests/appservice/test_scheduler.py | 19 ++++---- tests/handlers/test_appservice.py | 5 ++- 8 files changed, 122 insertions(+), 103 deletions(-) create mode 100644 changelog.d/7973.misc (limited to 'tests') diff --git a/changelog.d/7973.misc b/changelog.d/7973.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7973.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 0323256472..1ffdc1ed95 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -15,11 +15,9 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.types import GroupID, get_domain_from_id -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -43,7 +41,7 @@ class AppServiceTransaction(object): Args: as_api(ApplicationServiceApi): The API to use to send. Returns: - A Deferred which resolves to True if the transaction was sent. + An Awaitable which resolves to True if the transaction was sent. """ return as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id @@ -172,8 +170,7 @@ class ApplicationService(object): return regex_obj["exclusive"] return False - @defer.inlineCallbacks - def _matches_user(self, event, store): + async def _matches_user(self, event, store): if not event: return False @@ -188,12 +185,12 @@ class ApplicationService(object): if not store: return False - does_match = yield self._matches_user_in_member_list(event.room_id, store) + does_match = await self._matches_user_in_member_list(event.room_id, store) return does_match - @cachedInlineCallbacks(num_args=1, cache_context=True) - def _matches_user_in_member_list(self, room_id, store, cache_context): - member_list = yield store.get_users_in_room( + @cached(num_args=1, cache_context=True) + async def _matches_user_in_member_list(self, room_id, store, cache_context): + member_list = await store.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) @@ -208,35 +205,33 @@ class ApplicationService(object): return self.is_interested_in_room(event.room_id) return False - @defer.inlineCallbacks - def _matches_aliases(self, event, store): + async def _matches_aliases(self, event, store): if not store or not event: return False - alias_list = yield store.get_aliases_for_room(event.room_id) + alias_list = await store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): return True return False - @defer.inlineCallbacks - def is_interested(self, event, store=None): + async def is_interested(self, event, store=None) -> bool: """Check if this service is interested in this event. Args: event(Event): The event to check. store(DataStore) Returns: - bool: True if this service would like to know about this event. + True if this service would like to know about this event. 
""" # Do cheap checks first if self._matches_room_id(event): return True - if (yield self._matches_aliases(event, store)): + if await self._matches_aliases(event, store): return True - if (yield self._matches_user(event, store)): + if await self._matches_user(event, store): return True return False diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 1e0e4d497d..db578bda79 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -93,13 +93,12 @@ class ApplicationServiceApi(SimpleHttpClient): hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - @defer.inlineCallbacks - def query_user(self, service, user_id): + async def query_user(self, service, user_id): if service.url is None: return False uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) try: - response = yield self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -110,14 +109,12 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_user to %s threw exception %s", uri, ex) return False - @defer.inlineCallbacks - def query_alias(self, service, alias): + async def query_alias(self, service, alias): if service.url is None: return False uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) - response = None try: - response = yield self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -128,8 +125,7 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) return False - @defer.inlineCallbacks - def query_3pe(self, service, kind, protocol, fields): + async def query_3pe(self, service, kind, protocol, fields): if kind == ThirdPartyEntityKind.USER: required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: @@ -146,7 +142,7 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - response = yield self.get_json(uri, fields) + response = await self.get_json(uri, fields) if not isinstance(response, list): logger.warning( "query_3pe to %s returned an invalid response %r", uri, response @@ -202,8 +198,7 @@ class ApplicationServiceApi(SimpleHttpClient): key = (service.id, protocol) return self.protocol_meta_cache.wrap(key, _get) - @defer.inlineCallbacks - def push_bulk(self, service, events, txn_id=None): + async def push_bulk(self, service, events, txn_id=None): if service.url is None: return True @@ -218,7 +213,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: - yield self.put_json( + await self.put_json( uri=uri, json_body={"events": events}, args={"access_token": service.hs_token}, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 9998f822f1..d5204b1314 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -50,8 +50,6 @@ components. 
""" import logging -from twisted.internet import defer - from synapse.appservice import ApplicationServiceState from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object): self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) - @defer.inlineCallbacks - def start(self): + async def start(self): logger.info("Starting appservice scheduler") # check for any DOWN ASes and start recoverers for them. - services = yield self.store.get_appservices_by_state( + services = await self.store.get_appservices_by_state( ApplicationServiceState.DOWN ) @@ -117,8 +114,7 @@ class _ServiceQueuer(object): "as-sender-%s" % (service.id,), self._send_request, service ) - @defer.inlineCallbacks - def _send_request(self, service): + async def _send_request(self, service): # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -130,7 +126,7 @@ class _ServiceQueuer(object): if not events: return try: - yield self.txn_ctrl.send(service, events) + await self.txn_ctrl.send(service, events) except Exception: logger.exception("AS request failed") finally: @@ -162,36 +158,33 @@ class _TransactionController(object): # for UTs self.RECOVERER_CLASS = _Recoverer - @defer.inlineCallbacks - def send(self, service, events): + async def send(self, service, events): try: - txn = yield self.store.create_appservice_txn(service=service, events=events) - service_is_up = yield self._is_service_up(service) + txn = await self.store.create_appservice_txn(service=service, events=events) + service_is_up = await self._is_service_up(service) if service_is_up: - sent = yield txn.send(self.as_api) + sent = await txn.send(self.as_api) if sent: - yield txn.complete(self.store) + await txn.complete(self.store) else: run_in_background(self._on_txn_fail, service) except Exception: logger.exception("Error creating appservice transaction") run_in_background(self._on_txn_fail, service) - @defer.inlineCallbacks - def on_recovered(self, recoverer): + async def on_recovered(self, recoverer): logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) self.recoverers.pop(recoverer.service.id) logger.info("Remaining active recoverers: %s", len(self.recoverers)) - yield self.store.set_appservice_state( + await self.store.set_appservice_state( recoverer.service, ApplicationServiceState.UP ) - @defer.inlineCallbacks - def _on_txn_fail(self, service): + async def _on_txn_fail(self, service): try: - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + await self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") @@ -211,9 +204,8 @@ class _TransactionController(object): recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) - @defer.inlineCallbacks - def _is_service_up(self, service): - state = yield self.store.get_appservice_state(service) + async def _is_service_up(self, service): + state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None @@ -254,25 +246,24 @@ class _Recoverer(object): self.backoff_counter += 1 self.recover() - @defer.inlineCallbacks - def retry(self): + async def retry(self): logger.info("Starting retries 
on %s", self.service.id) try: while True: - txn = yield self.store.get_oldest_unsent_txn(self.service) + txn = await self.store.get_oldest_unsent_txn(self.service) if not txn: # nothing left: we're done! - self.callback(self) + await self.callback(self) return logger.info( "Retrying transaction %s for AS ID %s", txn.id, txn.service.id ) - sent = yield txn.send(self.as_api) + sent = await txn.send(self.as_api) if not sent: break - yield txn.complete(self.store) + await txn.complete(self.store) # reset the backoff counter and then process the next transaction self.backoff_counter = 1 diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 92d4c6e16c..fbc56c351b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -27,7 +27,6 @@ from synapse.metrics import ( event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import log_failure from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -100,10 +99,11 @@ class ApplicationServicesHandler(object): if not self.started_scheduler: - def start_scheduler(): - return self.scheduler.start().addErrback( - log_failure, "Application Services Failure" - ) + async def start_scheduler(): + try: + return self.scheduler.start() + except Exception: + logger.error("Application Services Failure") run_as_background_process("as_scheduler", start_scheduler) self.started_scheduler = True diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 4003869ed6..236b608d58 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_member_is_checked(self): @@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_match(self): @@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_no_match(self): @@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" - 
self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_alias_match(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#irc_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) def test_non_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( @@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#xmpp_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertFalse((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertFalse( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_regex_multiple_matches(self): @@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_barfoo:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_interested_in_self(self): @@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) - self.store.get_users_in_room.return_value = [ - "@alice:here", - "@irc_fo:here", # AS user - "@bob:here", - ] - self.store.get_aliases_for_room.return_value = [] + # Note that @irc_fo:here is the AS user. 
+ self.store.get_users_in_room.return_value = defer.succeed( + ["@alice:here", "@irc_fo:here", "@bob:here"] + ) + self.store.get_aliases_for_room.return_value = defer.succeed([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( - (yield self.service.is_interested(event=self.event, store=self.store)) + ( + yield defer.ensureDeferred( + self.service.is_interested(event=self.event, store=self.store) + ) + ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 52f89d3f83..68a4caabbf 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -25,6 +25,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest +from tests.test_utils import make_awaitable from ..utils import MockClock @@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.get_appservice_state = Mock( return_value=defer.succeed(ApplicationServiceState.UP) ) - txn.send = Mock(return_value=defer.succeed(True)) + txn.send = Mock(return_value=make_awaitable(True)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): return_value=defer.succeed(ApplicationServiceState.UP) ) self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=defer.succeed(False)) # fails to send + txn.send = Mock(return_value=make_awaitable(False)) # fails to send self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events @@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=True) + txn.send = Mock(return_value=make_awaitable(True)) + txn.complete.return_value = make_awaitable(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=False) + txn.send = Mock(return_value=make_awaitable(False)) + txn.complete.return_value = make_awaitable(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): 
self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=True) # successfully send the txn + txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ebabe9a7d6..628f7d8db0 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from tests.test_utils import make_awaitable from tests.utils import MockClock from .. import unittest @@ -117,7 +118,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_interested_in_alias=False), ] - self.mock_as_api.query_alias.return_value = defer.succeed(True) + self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services self.mock_store.get_association_from_room_alias.return_value = defer.succeed( Mock(room_id=room_id, servers=servers) @@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def _mkservice(self, is_interested): service = Mock() - service.is_interested.return_value = defer.succeed(is_interested) + service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" return service -- cgit 1.5.1 From c978f6c4515a631f289aedb1844d8579b9334aaa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 08:01:33 -0400 Subject: Convert federation client to async/await. (#7975) --- changelog.d/7975.misc | 1 + contrib/cmdclient/console.py | 16 ++-- synapse/crypto/keyring.py | 60 +++++++------- synapse/federation/federation_client.py | 8 +- synapse/federation/sender/__init__.py | 19 ++--- synapse/federation/transport/client.py | 96 ++++++++++------------- synapse/handlers/groups_local.py | 35 ++++----- synapse/http/matrixfederationclient.py | 72 ++++++++--------- tests/crypto/test_keyring.py | 11 +-- tests/federation/test_complexity.py | 21 ++--- tests/federation/test_federation_sender.py | 10 +-- tests/handlers/test_directory.py | 5 +- tests/handlers/test_profile.py | 3 +- tests/http/test_fedclient.py | 50 ++++++++---- tests/replication/test_federation_sender_shard.py | 13 ++- tests/rest/admin/test_admin.py | 4 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- tests/test_federation.py | 2 +- 18 files changed, 209 insertions(+), 221 deletions(-) create mode 100644 changelog.d/7975.misc (limited to 'tests') diff --git a/changelog.d/7975.misc b/changelog.d/7975.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7975.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
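A note on the mocking pattern in the test hunks above: `make_awaitable` (imported from `tests.test_utils` in these diffs) turns a plain value into something that can be awaited, so a `Mock` can stand in for a method that has become `async def`. The following is only a minimal sketch of such a helper, an assumption for illustration; the real `tests.test_utils.make_awaitable` may be implemented differently (for example with a pre-completed future so the result can be awaited more than once):

    from unittest.mock import Mock

    async def make_awaitable(result):
        # Illustrative only: each call produces a coroutine that resolves
        # immediately to `result`. A coroutine can be awaited only once, so a
        # fresh call is needed for every place the value will be awaited.
        return result

    async def example():
        txn = Mock()
        # Mirrors the scheduler tests above: the mocked send() now returns
        # an awaitable, and the code under test awaits it.
        txn.send = Mock(return_value=make_awaitable(True))
        sent = await txn.send("fake-as-api")
        assert sent is True

Running `example()` under any event loop (e.g. `asyncio.run(example())`) exercises the same pattern that the converted scheduler code drives via Twisted.
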
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 77422f5e5d..dfc1d294dc 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -609,13 +609,15 @@ class SynapseCmd(cmd.Cmd): @defer.inlineCallbacks def _do_event_stream(self, timeout): - res = yield self.http_client.get_json( - self._url() + "/events", - { - "access_token": self._tok(), - "timeout": str(timeout), - "from": self.event_stream_token, - }, + res = yield defer.ensureDeferred( + self.http_client.get_json( + self._url() + "/events", + { + "access_token": self._tok(), + "timeout": str(timeout), + "from": self.event_stream_token, + }, + ) ) print(json.dumps(res, indent=4)) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index dbfc3e8972..443cde0b6d 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -632,18 +632,20 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - query_response = yield self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + query_response = yield defer.ensureDeferred( + self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + "server_keys": { + server_name: { + key_id: {"minimum_valid_until_ts": min_valid_ts} + for key_id, min_valid_ts in server_keys.items() + } + for server_name, server_keys in keys_to_fetch.items() } - for server_name, server_keys in keys_to_fetch.items() - } - }, + }, + ) ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -792,23 +794,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher): time_now_ms = self.clock.time_msec() try: - response = yield self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, + response = yield defer.ensureDeferred( + self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). 
+ timeout=10000, + ) ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 994e6c8d5a..38ac7ec699 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -135,7 +135,7 @@ class FederationClient(FederationBase): and try the request anyway. Returns: - a Deferred which will eventually yield a JSON object from the + a Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels(query_type).inc() @@ -157,7 +157,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() @@ -180,7 +180,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() @@ -900,7 +900,7 @@ class FederationClient(FederationBase): party instance Returns: - Deferred[Dict[str, Any]]: The response from the remote server, or None if + Awaitable[Dict[str, Any]]: The response from the remote server, or None if `remote_server` is the same as the local server_name Raises: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index ba4ddd2370..8f549ae6ee 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -288,8 +288,7 @@ class FederationSender(object): for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu, order) - @defer.inlineCallbacks - def send_read_receipt(self, receipt: ReadReceipt): + async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room Args: @@ -330,9 +329,7 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield defer.ensureDeferred( - self.state.get_current_hosts_in_room(room_id) - ) + domains = await self.state.get_current_hosts_in_room(room_id) domains = [ d for d in domains @@ -387,8 +384,7 @@ class FederationSender(object): queue.flush_read_receipts_for_room(room_id) @preserve_fn # the caller should not yield on this - @defer.inlineCallbacks - def send_presence(self, states: List[UserPresenceState]): + async def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. 
This actually queues up the presence states ready for sending and @@ -423,7 +419,7 @@ class FederationSender(object): if not states_map: break - yield self._process_presence_inner(list(states_map.values())) + await self._process_presence_inner(list(states_map.values())) except Exception: logger.exception("Error sending presence states to servers") finally: @@ -450,14 +446,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") - @defer.inlineCallbacks - def _process_presence_inner(self, states: List[UserPresenceState]): + async def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = yield defer.ensureDeferred( - get_interested_remotes(self.store, states, self.state) - ) + hosts_and_states = await get_interested_remotes(self.store, states, self.state) for destinations, states in hosts_and_states: for destination in destinations: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cfdf23d366..9ea821dbb2 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -18,8 +18,6 @@ import logging import urllib from typing import Any, Dict, Optional -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( @@ -51,7 +49,7 @@ class TransportLayerClient(object): event_id (str): The event we want the context at. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) @@ -75,7 +73,7 @@ class TransportLayerClient(object): giving up. None indicates no timeout. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) @@ -96,7 +94,7 @@ class TransportLayerClient(object): limit (int) Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", @@ -118,16 +116,15 @@ class TransportLayerClient(object): destination, path=path, args=args, try_trailing_slash_on_400=True ) - @defer.inlineCallbacks @log_function - def send_transaction(self, transaction, json_data_callback=None): + async def send_transaction(self, transaction, json_data_callback=None): """ Sends the given Transaction to its destination Args: transaction (Transaction) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. 
Fails with ``HTTPRequestException`` if we get an HTTP response @@ -154,7 +151,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send/%s", transaction.transaction_id) - response = yield self.client.put_json( + response = await self.client.put_json( transaction.destination, path=path, data=json_data, @@ -166,14 +163,13 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def make_query( + async def make_query( self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False ): path = _create_v1_path("/query/%s", query_type) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=args, @@ -184,9 +180,10 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def make_membership_event(self, destination, room_id, user_id, membership, params): + async def make_membership_event( + self, destination, room_id, user_id, membership, params + ): """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -200,7 +197,7 @@ class TransportLayerClient(object): request. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body (ie, the new event). Fails with ``HTTPRequestException`` if we get an HTTP response @@ -231,7 +228,7 @@ class TransportLayerClient(object): ignore_backoff = True retry_on_dns_fail = True - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=params, @@ -242,34 +239,31 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def send_join_v1(self, destination, room_id, event_id, content): + async def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_join_v2(self, destination, room_id, event_id, content): + async def send_join_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_leave_v1(self, destination, room_id, event_id, content): + async def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -282,12 +276,11 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_leave_v2(self, destination, room_id, event_id, content): + async def send_leave_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -300,31 +293,28 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_invite_v1(self, destination, room_id, event_id, content): + async def 
send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def send_invite_v2(self, destination, room_id, event_id, content): + async def send_invite_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -355,7 +345,7 @@ class TransportLayerClient(object): data["filter"] = search_filter try: - response = yield self.client.post_json( + response = await self.client.post_json( destination=remote_server, path=path, data=data, ignore_backoff=True ) except HttpResponseException as e: @@ -381,7 +371,7 @@ class TransportLayerClient(object): args["since"] = [since_token] try: - response = yield self.client.get_json( + response = await self.client.get_json( destination=remote_server, path=path, args=args, ignore_backoff=True ) except HttpResponseException as e: @@ -396,29 +386,26 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, destination, room_id, event_dict): + async def exchange_third_party_invite(self, destination, room_id, event_dict): path = _create_v1_path("/exchange_third_party_invite/%s", room_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=event_dict ) return response - @defer.inlineCallbacks @log_function - def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json(destination=destination, path=path) + content = await self.client.get_json(destination=destination, path=path) return content - @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content, timeout): + async def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -453,14 +440,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/keys/query") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def query_user_devices(self, destination, user_id, timeout): + async def query_user_devices(self, destination, user_id, timeout): """Query the devices for a user id hosted on a remote server. 
Response: @@ -493,14 +479,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/devices/%s", user_id) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content, timeout): + async def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -532,14 +517,13 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def get_missing_events( + async def get_missing_events( self, destination, room_id, @@ -551,7 +535,7 @@ class TransportLayerClient(object): ): path = _create_v1_path("/get_missing_events/%s", room_id) - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data={ diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ecdb12a7bf..0e2656ccb3 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -23,39 +23,32 @@ logger = logging.getLogger(__name__) def _create_rerouter(func_name): - """Returns a function that looks at the group id and calls the function + """Returns an async function that looks at the group id and calls the function on federation or the local group server if the group is local """ - def f(self, group_id, *args, **kwargs): + async def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): - return getattr(self.groups_server_handler, func_name)( + return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs ) else: destination = get_domain_from_id(group_id) - d = getattr(self.transport_client, func_name)( - destination, group_id, *args, **kwargs - ) - # Capture errors returned by the remote homeserver and - # re-throw specific errors as SynapseErrors. This is so - # when the remote end responds with things like 403 Not - # In Group, we can communicate that to the client instead - # of a 500. - def http_response_errback(failure): - failure.trap(HttpResponseException) - e = failure.value + try: + return await getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + except HttpResponseException as e: + # Capture errors returned by the remote homeserver and + # re-throw specific errors as SynapseErrors. This is so + # when the remote end responds with things like 403 Not + # In Group, we can communicate that to the client instead + # of a 500. 
raise e.to_synapse_error() - - def request_failed_errback(failure): - failure.trap(RequestSendFailed) + except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") - d.addErrback(http_response_errback) - d.addErrback(request_failed_errback) - return d - return f diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ea026ed9f4..2a6373937a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -121,8 +121,7 @@ class MatrixFederationRequest(object): return self.json -@defer.inlineCallbacks -def _handle_json_response(reactor, timeout_sec, request, response): +async def _handle_json_response(reactor, timeout_sec, request, response): """ Reads the JSON body of a response, with a timeout @@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): d = treq.json_content(response) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response", request.txn_id, request.destination, @@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) - @defer.inlineCallbacks - def _send_request_with_optional_trailing_slash( + async def _send_request_with_optional_trailing_slash( self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request @@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object): (except 429). Returns: - Deferred[Dict]: Parsed JSON response body. + Dict: Parsed JSON response body. """ try: - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object): logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) return response - @defer.inlineCallbacks - def _send_request( + async def _send_request( self, request, retry_on_dns_fail=True, @@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred[twisted.web.client.Response]: resolves with the HTTP + twisted.web.client.Response: resolves with the HTTP response object on success. Raises: @@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object): ): raise FederationDeniedError(request.destination) - limiter = yield synapse.util.retryutils.get_retry_limiter( + limiter = await synapse.util.retryutils.get_retry_limiter( request.destination, self.clock, self._store, @@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object): reactor=self.reactor, ) - response = yield request_deferred + response = await request_deferred except TimeoutError as e: raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: @@ -474,7 +471,7 @@ class MatrixFederationHttpClient(object): ) try: - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. 
@@ -528,7 +525,7 @@ class MatrixFederationHttpClient(object): delay, ) - yield self.clock.sleep(delay) + await self.clock.sleep(delay) retries_left -= 1 else: raise @@ -591,8 +588,7 @@ class MatrixFederationHttpClient(object): ) return auth_headers - @defer.inlineCallbacks - def put_json( + async def put_json( self, destination, path, @@ -636,7 +632,7 @@ class MatrixFederationHttpClient(object): enabled. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -658,7 +654,7 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request_with_optional_trailing_slash( + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=backoff_on_404, @@ -667,14 +663,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def post_json( + async def post_json( self, destination, path, @@ -707,7 +702,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -725,7 +720,7 @@ class MatrixFederationHttpClient(object): method="POST", destination=destination, path=path, query=args, json=data ) - response = yield self._send_request( + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, @@ -737,13 +732,12 @@ class MatrixFederationHttpClient(object): else: _sec_timeout = self.default_timeout - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, _sec_timeout, request, response ) return body - @defer.inlineCallbacks - def get_json( + async def get_json( self, destination, path, @@ -775,7 +769,7 @@ class MatrixFederationHttpClient(object): response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -792,7 +786,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request_with_optional_trailing_slash( + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=False, @@ -801,14 +795,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def delete_json( + async def delete_json( self, destination, path, @@ -836,7 +829,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. 
Raises: @@ -853,20 +846,19 @@ class MatrixFederationHttpClient(object): method="DELETE", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def get_file( + async def get_file( self, destination, path, @@ -886,7 +878,7 @@ class MatrixFederationHttpClient(object): and try the request anyway. Returns: - Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + tuple[int, dict]: Resolves with an (int,dict) tuple of the file length and a dict of the response headers. Raises: @@ -903,7 +895,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) @@ -912,7 +904,7 @@ class MatrixFederationHttpClient(object): try: d = _readBodyToFile(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) - length = yield make_deferred_yieldable(d) + length = await make_deferred_yieldable(d) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f9ce609923..e0ad8e8a77 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): } persp_deferred = defer.Deferred() - @defer.inlineCallbacks - def get_perspectives(**kwargs): + async def get_perspectives(**kwargs): self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): - yield persp_deferred + await persp_deferred return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): } signedjson.sign.sign_json(response, SERVER_NAME, testkey) - def get_json(destination, path, **kwargs): + async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) self.assertEqual(path, "/_matrix/key/v2/server/key1") return response @@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect a perspectives-server key query """ - def post_json(destination, path, data, **kwargs): + async def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(path, "/_matrix/key/v2/query") @@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # remove the perspectives server's signature response = build_response() del response["signatures"][self.mock_perspective_server.server_name] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") # remove the origin server's signature response = build_response() del response["signatures"][SERVER_NAME] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") diff --git a/tests/federation/test_complexity.py 
b/tests/federation/test_complexity.py index 5cd0510f0d..b8ca118716 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -78,9 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -109,9 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -147,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -203,9 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -233,9 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d1bd18da39..5f512ff8bf 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = 
make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): receipt = ReadReceipt( "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() mock_send_transaction.assert_not_called() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 00bb776271..bc0c5aefdc 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -16,8 +16,6 @@ from mock import Mock -from twisted.internet import defer - import synapse import synapse.api.errors from synapse.api.constants import EventTypes @@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): @@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f1347cd25..d70e1fc608 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_other_name(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index fff4f0cbf4..ac598249e4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. 
""" - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. 
""" - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8d4dbf232e..83f9aa291c 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -16,8 +16,6 @@ import logging from mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource @@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. 
""" mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index b1a4decced..0f1144fe1e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + async def get_file(destination, path, output_stream, args=None, max_size=None): """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): d = Deferred() d.addCallback(write_to) self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + return await make_deferred_yieldable(d) client = Mock() client.get_file = get_file diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 99eb477149..6850c666be 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json(destination, path, ignore_backoff=False, **kwargs): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - def post_json(destination, path, data): + async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, "/_matrix/key/v2/query", diff --git a/tests/test_federation.py b/tests/test_federation.py index 87a16d7d7a..c2f12c2741 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. 
""" - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} -- cgit 1.5.1 From 18de00adb4471a55b504f4afb9f29facf0a51785 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 31 Jul 2020 14:34:42 +0100 Subject: Add ratelimiting on joins --- docs/sample_config.yaml | 12 ++++++++++++ synapse/config/ratelimiting.py | 21 +++++++++++++++++++++ synapse/handlers/room_member.py | 37 +++++++++++++++++++++++++++++++++++-- tests/utils.py | 4 ++++ 4 files changed, 72 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b21e36bb6d..fef503479e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -731,6 +731,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. +# - two for ratelimiting number of rooms a user can join, "local" for when +# users are joining rooms the server is already in (this is cheap) vs +# "remote" for when users are trying to join rooms not on the server (which +# can be more expensive) # # The defaults are as shown below. # @@ -756,6 +760,14 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_admin_redaction: # per_second: 1 # burst_count: 50 +# +#rc_joins: +# local: +# per_second: 0.1 +# burst_count: 3 +# remote: +# per_second: 0.01 +# burst_count: 3 # Ratelimiting settings for incoming federation diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 2dd94bae2b..b2c78ac40c 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -93,6 +93,15 @@ class RatelimitConfig(Config): if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) + self.rc_joins_local = RateLimitConfig( + config.get("rc_joins", {}).get("local", {}), + defaults={"per_second": 0.1, "burst_count": 3}, + ) + self.rc_joins_remote = RateLimitConfig( + config.get("rc_joins", {}).get("remote", {}), + defaults={"per_second": 0.01, "burst_count": 3}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -118,6 +127,10 @@ class RatelimitConfig(Config): # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. + # - two for ratelimiting number of rooms a user can join, "local" for when + # users are joining rooms the server is already in (this is cheap) vs + # "remote" for when users are trying to join rooms not on the server (which + # can be more expensive) # # The defaults are as shown below. 
# @@ -143,6 +156,14 @@ class RatelimitConfig(Config): #rc_admin_redaction: # per_second: 1 # burst_count: 50 + # + #rc_joins: + # local: + # per_second: 0.1 + # burst_count: 3 + # remote: + # per_second: 0.01 + # burst_count: 3 # Ratelimiting settings for incoming federation diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a1a8fa1d3b..822ca9da6a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -22,7 +22,8 @@ from unpaddedbase64 import encode_base64 from synapse import types from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import EventFormatVersions from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase @@ -77,6 +78,17 @@ class RoomMemberHandler(object): if self._is_on_event_persistence_instance: self.persist_event_storage = hs.get_storage().persistence + self._join_rate_limiter_local = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, + burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + ) + self._join_rate_limiter_remote = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, + burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -441,7 +453,28 @@ class RoomMemberHandler(object): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - if not is_host_in_room: + if is_host_in_room: + time_now_s = self.clock.time() + allowed, time_allowed = self._join_rate_limiter_local.can_do_action( + requester.user.to_string(), + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + + else: + time_now_s = self.clock.time() + allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( + requester.user.to_string(), + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) diff --git a/tests/utils.py b/tests/utils.py index ac643679aa..a8e85436f9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -154,6 +154,10 @@ def default_config(name, parse=False): "account": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins": { + "local": {"per_second": 10000, "burst_count": 10000}, + "remote": {"per_second": 10000, "burst_count": 10000}, + }, "saml2_enabled": False, "public_baseurl": None, "default_identity_server": None, -- cgit 1.5.1 From 2a89ce8cd4d563ef22995882e9548f1aff3e42f1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 3 Aug 2020 08:29:01 -0400 Subject: Convert the crypto module to async/await. 
(#8003) --- changelog.d/8003.misc | 1 + synapse/crypto/keyring.py | 201 ++++++++++++++++++++----------------------- tests/crypto/test_keyring.py | 39 ++++----- 3 files changed, 109 insertions(+), 132 deletions(-) create mode 100644 changelog.d/8003.misc (limited to 'tests') diff --git a/changelog.d/8003.misc b/changelog.d/8003.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8003.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 443cde0b6d..28ef7cfdb9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -223,8 +223,7 @@ class Keyring(object): return results - @defer.inlineCallbacks - def _start_key_lookups(self, verify_requests): + async def _start_key_lookups(self, verify_requests): """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. @@ -245,7 +244,7 @@ class Keyring(object): server_to_request_ids.setdefault(server_name, set()).add(request_id) # Wait for any previous lookups to complete before proceeding. - yield self.wait_for_previous_lookups(server_to_request_ids.keys()) + await self.wait_for_previous_lookups(server_to_request_ids.keys()) # take out a lock on each of the servers by sticking a Deferred in # key_downloads @@ -283,15 +282,14 @@ class Keyring(object): except Exception: logger.exception("Error starting key lookups") - @defer.inlineCallbacks - def wait_for_previous_lookups(self, server_names): + async def wait_for_previous_lookups(self, server_names) -> None: """Waits for any previous key lookups for the given servers to finish. Args: server_names (Iterable[str]): list of servers which we want to look up Returns: - Deferred[None]: resolves once all key lookups for the given servers have + Resolves once all key lookups for the given servers have completed. Follows the synapse rules of logcontext preservation. """ loop_count = 1 @@ -309,7 +307,7 @@ class Keyring(object): loop_count, ) with PreserveLoggingContext(): - yield defer.DeferredList((w[1] for w in wait_on)) + await defer.DeferredList((w[1] for w in wait_on)) loop_count += 1 @@ -326,44 +324,44 @@ class Keyring(object): remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} - @defer.inlineCallbacks - def do_iterations(): - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - yield self._attempt_key_fetches_with_fetcher(f, remaining_requests) + async def do_iterations(): + try: + with Measure(self.clock, "get_server_verify_keys"): + for f in self._key_fetchers: + if not remaining_requests: + return + await self._attempt_key_fetches_with_fetcher( + f, remaining_requests + ) - # look for any requests which weren't satisfied + # look for any requests which weren't satisfied + with PreserveLoggingContext(): + for verify_request in remaining_requests: + verify_request.key_ready.errback( + SynapseError( + 401, + "No key for %s with ids in %s (min_validity %i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, + ), + Codes.UNAUTHORIZED, + ) + ) + except Exception as err: + # we don't really expect to get here, because any errors should already + # have been caught and logged. But if we do, let's log the error and make + # sure that all of the deferreds are resolved. 
+ logger.error("Unexpected error in _get_server_verify_keys: %s", err) with PreserveLoggingContext(): for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) - ) - - def on_err(err): - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) + if not verify_request.key_ready.called: + verify_request.key_ready.errback(err) - run_in_background(do_iterations).addErrback(on_err) + run_in_background(do_iterations) - @defer.inlineCallbacks - def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): """Use a key fetcher to attempt to satisfy some key requests Args: @@ -390,7 +388,7 @@ class Keyring(object): verify_request.minimum_valid_until_ts, ) - results = yield fetcher.get_keys(missing_keys) + results = await fetcher.get_keys(missing_keys) completed = [] for verify_request in remaining_requests: @@ -423,7 +421,7 @@ class Keyring(object): class KeyFetcher(object): - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """see KeyFetcher.get_keys""" keys_to_fetch = ( @@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher): for key_id in keys_for_server.keys() ) - res = yield self.store.get_server_verify_keys(keys_to_fetch) + res = await self.store.get_server_verify_keys(keys_to_fetch) keys = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key @@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object): self.store = hs.get_datastore() self.config = hs.get_config() - @defer.inlineCallbacks - def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response(self, from_server, response_json, time_added_ms): """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object): key_json_bytes = encode_canonical_json(response_json) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_http_client() self.key_servers = self.config.key_servers - @defer.inlineCallbacks - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """see KeyFetcher.get_keys""" - @defer.inlineCallbacks - def get_key(key_server): + async def get_key(key_server): try: - result = yield self.get_server_verify_key_v2_indirect( + result = await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) return result @@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return {} - results = 
yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.gatherResults( [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, @@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return union_of_keys - @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): the keys Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map from server_name -> key_id -> FetchKeyResult Raises: @@ -632,20 +625,18 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - query_response = yield defer.ensureDeferred( - self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() - } - for server_name, server_keys in keys_to_fetch.items() + query_response = await self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + "server_keys": { + server_name: { + key_id: {"minimum_valid_until_ts": min_valid_ts} + for key_id, min_valid_ts in server_keys.items() } - }, - ) + for server_name, server_keys in keys_to_fetch.items() + } + }, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): try: self._validate_perspectives_response(key_server, response) - processed_response = yield self.process_v2_response( + processed_response = await self.process_v2_response( perspective_name, response, time_added_ms=time_now_ms ) except KeyLookupError as e: @@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) keys.setdefault(server_name, {}).update(processed_response) - yield self.store.store_server_verify_keys( + await self.store.store_server_verify_keys( perspective_name, time_now_ms, added_keys ) @@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher): self.clock = hs.get_clock() self.client = hs.get_http_client() - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """ Args: keys_to_fetch (dict[str, iterable[str]]): the keys to be fetched. 
server_name -> key_ids Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: map from server_name -> key_id -> FetchKeyResult """ results = {} - @defer.inlineCallbacks - def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item): server_name, key_ids = key_to_fetch_item try: - keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids) + keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys except KeyLookupError as e: logger.warning( @@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback( - lambda _: results - ) + return await yieldable_gather_results( + get_key, keys_to_fetch.items() + ).addCallback(lambda _: results) - @defer.inlineCallbacks - def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct(self, server_name, key_ids): """ Args: @@ -794,25 +783,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher): time_now_ms = self.clock.time_msec() try: - response = yield defer.ensureDeferred( - self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, - ) + response = await self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). 
+ timeout=10000, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve @@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher): % (server_name, response["server_name"]) ) - response_keys = yield self.process_v2_response( + response_keys = await self.process_v2_response( from_server=server_name, response_json=response, time_added_ms=time_now_ms, ) - yield self.store.store_server_verify_keys( + await self.store.store_server_verify_keys( server_name, time_now_ms, ((server_name, key_id, key) for key_id, key in response_keys.items()), @@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -@defer.inlineCallbacks -def _handle_key_deferred(verify_request): +async def _handle_key_deferred(verify_request) -> None: """Waits for the key to become available, and then performs a verification Args: verify_request (VerifyJsonRequest): - Returns: - Deferred[None] - Raises: SynapseError if there was a problem performing the verification """ server_name = verify_request.server_name with PreserveLoggingContext(): - _, key_id, verify_key = yield verify_request.key_ready + _, key_id, verify_key = await verify_request.key_ready json_object = verify_request.json_object diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index e0ad8e8a77..0d4b05304b 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -40,6 +40,7 @@ from synapse.logging.context import ( from synapse.storage.keys import FetchKeyResult from tests import unittest +from tests.test_utils import make_awaitable class MockPerspectiveServer(object): @@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): with a null `ts_valid_until_ms` """ mock_fetcher = keyring.KeyFetcher() - mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) + mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) @@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key(1) - def get_keys(keys_to_fetch): + async def get_keys(keys_to_fetch): # there should only be one request object (with the max validity) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher = keyring.KeyFetcher() mock_fetcher.get_keys = Mock(side_effect=get_keys) @@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): """If the first fetcher cannot provide a recent enough key, we fall back""" key1 = signedjson.key.generate_signing_key(1) - def get_keys1(keys_to_fetch): + async def get_keys1(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) - } - } - ) + return { + "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} + } - def get_keys2(keys_to_fetch): + async def get_keys2(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): 
FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher1 = keyring.KeyFetcher() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) -- cgit 1.5.1 From 6812509807a914f1a709d4db2f7adb0bd6e58cc5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 3 Aug 2020 08:45:42 -0400 Subject: Implement handling of HTTP HEAD requests. (#7999) --- changelog.d/7999.bugfix | 1 + synapse/http/server.py | 16 +++++++++++----- tests/test_server.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7999.bugfix (limited to 'tests') diff --git a/changelog.d/7999.bugfix b/changelog.d/7999.bugfix new file mode 100644 index 0000000000..e0b8c4922f --- /dev/null +++ b/changelog.d/7999.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where HTTP HEAD requests resulted in a 400 error. diff --git a/synapse/http/server.py b/synapse/http/server.py index d4f9ad6e67..94ab29974a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -242,10 +242,12 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): no appropriate method exists. Can be overriden in sub classes for different routing. """ + # Treat HEAD requests as GET requests. + request_method = request.method.decode("ascii") + if request_method == "HEAD": + request_method = "GET" - method_handler = getattr( - self, "_async_render_%s" % (request.method.decode("ascii"),), None - ) + method_handler = getattr(self, "_async_render_%s" % (request_method,), None) if method_handler: raw_callback_return = method_handler(request) @@ -362,11 +364,15 @@ class JsonResource(DirectServeJsonResource): A tuple of the callback to use, the name of the servlet, and the key word arguments to pass to the callback """ + # Treat HEAD requests as GET requests. request_path = request.path.decode("ascii") + request_method = request.method + if request_method == b"HEAD": + request_method = b"GET" # Loop through all the registered callbacks to check if the method # and path regex match - for path_entry in self.path_regexs.get(request.method, []): + for path_entry in self.path_regexs.get(request_method, []): m = path_entry.pattern.match(request_path) if m: # We found a match! @@ -579,7 +585,7 @@ def set_cors_headers(request: Request): """ request.setHeader(b"Access-Control-Allow-Origin", b"*") request.setHeader( - b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS" + b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) request.setHeader( b"Access-Control-Allow-Headers", diff --git a/tests/test_server.py b/tests/test_server.py index 073b2362cc..d628070e48 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -157,6 +157,29 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def test_head_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + + def _callback(request, **kwargs): + return 200, {"result": True} + + res = JsonResource(self.homeserver) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet", + ) + + # The path was registered as GET, but this is a HEAD request. 
+ request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + self.assertNotIn("body", channel.result) + self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"]) + class OptionsResourceTests(unittest.TestCase): def setUp(self): @@ -255,7 +278,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() def test_good_response(self): - def callback(request): + async def callback(request): request.write(b"response") request.finish() @@ -275,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): with the right location. """ - def callback(request, **kwargs): + async def callback(request, **kwargs): raise RedirectException(b"/look/an/eagle", 301) res = WrapHtmlRequestHandlerTests.TestResource() @@ -295,7 +318,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): returned too """ - def callback(request, **kwargs): + async def callback(request, **kwargs): e = RedirectException(b"/no/over/there", 304) e.cookies.append(b"session=yespls") raise e @@ -312,3 +335,19 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(location_headers, [b"/no/over/there"]) cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) + + def test_head_request(self): + """A head request should work by being turned into a GET request.""" + + async def callback(request): + request.write(b"response") + request.finish() + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"HEAD", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + self.assertNotIn("body", channel.result) -- cgit 1.5.1 From 5d92a1428ceb4077801afc1785a5472e89fd9df3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 3 Aug 2020 13:54:24 -0700 Subject: Prevent join->join membership transitions changing member count (#7977) `StatsHandler` handles updates to the `current_state_delta_stream`, and updates room stats such as the amount of state events, joined users, etc. However, it counts every new join membership as a new user entering a room (and that user being in another room), whereas it's possible for a user's membership status to go from join -> join, for instance when they change their per-room profile information. This PR adds a check for join->join membership transitions, and bails out early, as none of the further checks are necessary at that point. Due to this bug, membership stats in many rooms have ended up being wildly larger than their true values. I am not sure if we also want to include a migration step which recalculates these statistics (possibly using the `_populate_stats_process_rooms` bg update). Bug introduced in the initial implementation https://github.com/matrix-org/synapse/pull/4338. 
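Concretely, the fix below amounts to treating an unchanged membership as a no-op before any counters are touched. A minimal sketch of the intended behaviour (illustrative only, not the actual `StatsHandler` code):

JOIN, INVITE = "join", "invite"

def apply_membership_delta(counts: dict, prev_membership, membership) -> None:
    if membership == prev_membership:
        # join -> join (e.g. a per-room displayname/avatar change): the user
        # is already counted, so no counters should move.
        return
    if membership == JOIN:
        counts["joined_members"] = counts.get("joined_members", 0) + 1
    elif membership == INVITE:
        counts["invited_members"] = counts.get("invited_members", 0) + 1
    # ...and the bucket for the previous membership is decremented, so a
    # genuine transition moves the user rather than double-counting them.
    if prev_membership == JOIN:
        counts["joined_members"] = counts.get("joined_members", 0) - 1
    elif prev_membership == INVITE:
        counts["invited_members"] = counts.get("invited_members", 0) - 1
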
--- changelog.d/7977.bugfix | 1 + synapse/handlers/stats.py | 2 +- .../main/schema/delta/58/12room_stats.sql | 32 +++++++++++++++ synapse/storage/data_stores/main/stats.py | 34 +++++++++++++--- tests/handlers/test_stats.py | 46 +++++++++++++++++++--- tests/rest/client/v1/utils.py | 24 ++++++++++- 6 files changed, 126 insertions(+), 13 deletions(-) create mode 100644 changelog.d/7977.bugfix create mode 100644 synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql (limited to 'tests') diff --git a/changelog.d/7977.bugfix b/changelog.d/7977.bugfix new file mode 100644 index 0000000000..c587f13055 --- /dev/null +++ b/changelog.d/7977.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory. diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 149f861239..249ffe2a55 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -232,7 +232,7 @@ class StatsHandler: if membership == prev_membership: pass # noop - if membership == Membership.JOIN: + elif membership == Membership.JOIN: room_stats_delta["joined_members"] += 1 elif membership == Membership.INVITE: room_stats_delta["invited_members"] += 1 diff --git a/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql b/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql new file mode 100644 index 0000000000..cade5dcca8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql @@ -0,0 +1,32 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Recalculate the stats for all rooms after the fix to joined_members erroneously +-- incrementing on per-room profile changes. + +-- Note that the populate_stats_process_rooms background update is already set to +-- run if you're upgrading from Synapse <1.0.0. + +-- Additionally, if you've upgraded to v1.18.0 (which doesn't include this fix), +-- this bg job runs, and then update to v1.19.0, you'd end up with only half of +-- your rooms having room stats recalculated after this fix was in place. + +-- So we've switched the old `populate_stats_process_rooms` background job to a +-- no-op, and then kick off a bg job with a new name, but with the same +-- functionality as the old one. This effectively restarts the background job +-- from the beginning, without running it twice in a row, supporting both +-- upgrade usecases. 
+INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 922400a7c3..40db8f594e 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -72,6 +72,9 @@ class StatsStore(StateDeltasStore): self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) + self.db.updates.register_background_update_handler( + "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 + ) self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) @@ -140,11 +143,30 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) async def _populate_stats_process_rooms(self, progress, batch_size): + """ + This was a background update which regenerated statistics for rooms. + + It has been replaced by StatsStore._populate_stats_process_rooms_2. This background + job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure + someone upgrading from None: + """ + Send a membership state event into a room. + + Args: + room: The ID of the room to send to + src: The mxid of the event sender + targ: The mxid of the event's target. The state key + membership: The type of membership event + extra_data: Extra information to include in the content of the event + tok: The user access token to use + expect_code: The expected HTTP response code + """ temp_id = self.auth_user_id self.auth_user_id = src @@ -97,6 +118,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok data = {"membership": membership} + data.update(extra_data) request, channel = make_request( self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") -- cgit 1.5.1 From e19de43eb5903c3b6ccca82334971ebc57fc38de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Aug 2020 07:21:47 -0400 Subject: Convert streams to async. (#8014) --- changelog.d/8014.misc | 1 + synapse/handlers/initial_sync.py | 4 ++-- synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 10 +++++----- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/notifier.py | 4 ++-- synapse/storage/data_stores/main/stream.py | 8 ++++---- synapse/streams/events.py | 22 +++++++++------------- .../test_resource_limits_server_notices.py | 2 +- 10 files changed, 27 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8014.misc (limited to 'tests') diff --git a/changelog.d/8014.misc b/changelog.d/8014.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8014.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
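The mechanical shape of this conversion (and of the crypto one above) is the usual Twisted translation from generator-based coroutines to native ones; a hedged before/after sketch, using a hypothetical `store.fetch()`:

from twisted.internet import defer

# Before: a Twisted generator-based coroutine.
@defer.inlineCallbacks
def get_thing_old(store):
    result = yield store.fetch()  # fetch() returns a Deferred
    return result

# After: a native coroutine. Within Twisted, `await` works on Deferreds as
# well as on other coroutines, so the body translates one-for-one.
async def get_thing_new(store):
    result = await store.fetch()
    return result

# And when a method stops being asynchronous altogether (as
# get_current_token() does in this patch), callers simply drop the
# yield/await and call it directly.
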
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f88bad5f25..ae6bd1d352 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -109,7 +109,7 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) @@ -360,7 +360,7 @@ class InitialSyncHandler(BaseHandler): current_state.values(), time_now ) - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index da06582d4b..487420bb5d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -309,7 +309,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - await self.hs.get_event_sources().get_current_token_for_pagination() + self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0c5b99234d..a8545255b1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,7 +22,7 @@ import logging import math import string from collections import OrderedDict -from typing import Optional, Tuple +from typing import Awaitable, Optional, Tuple from synapse.api.constants import ( EventTypes, @@ -1041,7 +1041,7 @@ class RoomEventSource(object): ): # We just ignore the key for now. - to_key = await self.get_current_key() + to_key = self.get_current_key() from_token = RoomStreamToken.parse(from_key) if from_token.topological: @@ -1081,10 +1081,10 @@ class RoomEventSource(object): return (events, end_key) - def get_current_key(self): - return self.store.get_room_events_max_id() + def get_current_key(self) -> str: + return "s%d" % (self.store.get_room_max_stream_ordering(),) - def get_current_key_for_room(self, room_id): + def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9b312a1558..d58f9788c5 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -340,7 +340,7 @@ class SearchHandler(BaseHandler): # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() contexts = {} for event in allowed_events: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eaa4eeadf7..5a19bac929 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -961,7 +961,7 @@ class SyncHandler(object): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. 
# Always use the `now_token` in `SyncResultBuilder` - now_token = await self.event_sources.get_current_token() + now_token = self.event_sources.get_current_token() logger.debug( "Calculating sync response for %r between %s and %s", diff --git a/synapse/notifier.py b/synapse/notifier.py index bd41f77852..22ab4a9da5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -320,7 +320,7 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = await self.event_sources.get_current_token() + current_token = self.event_sources.get_current_token() if room_ids is None: room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( @@ -397,7 +397,7 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = await self.event_sources.get_current_token() + from_token = self.event_sources.get_current_token() limit = pagination_config.limit diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 10d39b3699..f1334a6efc 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -39,6 +39,7 @@ what sort order was used: import abc import logging from collections import namedtuple +from typing import Optional from twisted.internet import defer @@ -557,19 +558,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return self.db.runInteraction("get_room_event_before_stream_ordering", _f) - @defer.inlineCallbacks - def get_room_events_max_id(self, room_id=None): + async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. 
""" - token = yield self.get_room_max_stream_ordering() + token = self.get_room_max_stream_ordering() if room_id is None: return "s%d" % (token,) else: - topo = yield self.db.runInteraction( + topo = await self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 5d3eddcfdc..393e34b9fb 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -15,8 +15,6 @@ from typing import Any, Dict -from twisted.internet import defer - from synapse.handlers.account_data import AccountDataEventSource from synapse.handlers.presence import PresenceEventSource from synapse.handlers.receipts import ReceiptEventSource @@ -40,19 +38,18 @@ class EventSources(object): } # type: Dict[str, Any] self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_current_token(self): + def get_current_token(self) -> StreamToken: push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() groups_key = self.store.get_group_stream_token() token = StreamToken( - room_key=(yield self.sources["room"].get_current_key()), - presence_key=(yield self.sources["presence"].get_current_key()), - typing_key=(yield self.sources["typing"].get_current_key()), - receipt_key=(yield self.sources["receipt"].get_current_key()), - account_data_key=(yield self.sources["account_data"].get_current_key()), + room_key=self.sources["room"].get_current_key(), + presence_key=self.sources["presence"].get_current_key(), + typing_key=self.sources["typing"].get_current_key(), + receipt_key=self.sources["receipt"].get_current_key(), + account_data_key=self.sources["account_data"].get_current_key(), push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, @@ -60,8 +57,7 @@ class EventSources(object): ) return token - @defer.inlineCallbacks - def get_current_token_for_pagination(self): + def get_current_token_for_pagination(self) -> StreamToken: """Get the current token for a given room to be used to paginate events. @@ -69,10 +65,10 @@ class EventSources(object): than `room`, since they are not used during pagination. Returns: - Deferred[StreamToken] + The current token for pagination. 
""" token = StreamToken( - room_key=(yield self.sources["room"].get_current_key()), + room_key=self.sources["room"].get_current_key(), presence_key=0, typing_key=0, receipt_key=0, diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 99908edba3..7f70353b0d 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -275,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) - token = self.get_success(self.event_source.get_current_token()) + token = self.event_source.get_current_token() events, _ = self.get_success( self.store.get_recent_events_for_room( room_id, limit=100, end_token=token.room_key -- cgit 1.5.1 From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- changelog.d/8033.misc | 1 + docs/user_directory.md | 2 +- scripts-dev/update_database | 4 +- scripts/synapse_port_db | 78 +- synapse/app/_base.py | 2 +- synapse/app/generic_worker.py | 14 +- synapse/app/homeserver.py | 6 +- synapse/config/database.py | 5 +- synapse/events/snapshot.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/presence.py | 4 +- synapse/module_api/__init__.py | 2 +- synapse/replication/slave/storage/_base.py | 6 +- synapse/replication/slave/storage/account_data.py | 8 +- synapse/replication/slave/storage/appservice.py | 2 +- synapse/replication/slave/storage/client_ips.py | 6 +- synapse/replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 8 +- synapse/replication/slave/storage/directory.py | 2 +- synapse/replication/slave/storage/events.py | 24 +- synapse/replication/slave/storage/filtering.py | 6 +- synapse/replication/slave/storage/groups.py | 6 +- synapse/replication/slave/storage/keys.py | 2 +- synapse/replication/slave/storage/presence.py | 6 +- synapse/replication/slave/storage/profile.py | 2 +- synapse/replication/slave/storage/push_rule.py | 2 +- synapse/replication/slave/storage/pushers.py | 6 +- synapse/replication/slave/storage/receipts.py | 6 +- synapse/replication/slave/storage/registration.py | 2 +- synapse/replication/slave/storage/room.py | 6 +- synapse/replication/slave/storage/transactions.py | 2 +- synapse/rest/admin/rooms.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/server.py | 4 +- synapse/state/__init__.py | 2 +- synapse/storage/__init__.py | 15 +- synapse/storage/_base.py | 6 +- synapse/storage/background_updates.py | 22 +- synapse/storage/data_stores/__init__.py | 97 - synapse/storage/data_stores/main/__init__.py | 592 ------ synapse/storage/data_stores/main/account_data.py | 430 ----- synapse/storage/data_stores/main/appservice.py | 372 ---- synapse/storage/data_stores/main/cache.py | 307 --- synapse/storage/data_stores/main/censor_events.py | 208 -- synapse/storage/data_stores/main/client_ips.py | 576 ------ synapse/storage/data_stores/main/deviceinbox.py | 476 ----- synapse/storage/data_stores/main/devices.py | 1309 ------------- synapse/storage/data_stores/main/directory.py | 195 -- synapse/storage/data_stores/main/e2e_room_keys.py | 439 ----- .../storage/data_stores/main/end_to_end_keys.py | 746 -------- 
 synapse/storage/data_stores/* (main and state stores, their schema deltas and full schemas) | removed
 synapse/storage/databases/*   (the same modules, recreated under the renamed package)       | added
 synapse/storage/database.py                        |    2 +-
 synapse/storage/persist_events.py                  |    6 +-
 synapse/storage/prepare_database.py                |   48 +-
 synapse/storage/util/id_generators.py              |    4 +-
 synmark/__init__.py                                |    6 +-
 tests/handlers/test_stats.py                       |   62 +-
 tests/handlers/test_user_directory.py              |   22 +-
 tests/replication/_base.py                         |    6 +-
 tests/rest/admin/test_room.py                      |    4 +-
 tests/storage/test__base.py                        |   16 +-
 tests/storage/test_appservice.py                   |    6 +-
 tests/storage/test_background_update.py            |    8 +-
 tests/storage/test_base.py                         |   22 +-
 tests/storage/test_cleanup_extrems.py              |   12 +-
 tests/storage/test_client_ips.py                   |   26 +-
 tests/storage/test_event_federation.py             |   16 +-
 tests/storage/test_event_push_actions.py           |   12 +-
 tests/storage/test_id_generators.py                |   14 +-
 tests/storage/test_monthly_active_users.py         |    6 +-
 tests/storage/test_redaction.py                    |    4 +-
 tests/storage/test_roommember.py                   |   12 +-
 tests/unittest.py                                  |    6 +-
 tox.ini                                            |    2 +-
 612 files changed, 36087 insertions(+), 36002 deletions(-)
 create mode 100644 changelog.d/8033.misc
 delete mode 100644 (every module, schema delta and full schema previously under synapse/storage/data_stores/)
 create mode 100644 (the same files, recreated under synapse/storage/databases/)
synapse/storage/databases/main/schema/full_schemas/16/application_services.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/im.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/keys.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/presence.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/profiles.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/push.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/redactions.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/state.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/transactions.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/16/users.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres create mode 100644 synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite create mode 100644 synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql create mode 100644 synapse/storage/databases/main/schema/full_schemas/README.md create mode 100644 synapse/storage/databases/main/search.py create mode 100644 synapse/storage/databases/main/signatures.py create mode 100644 synapse/storage/databases/main/state.py create mode 100644 synapse/storage/databases/main/state_deltas.py create mode 100644 synapse/storage/databases/main/stats.py create mode 100644 synapse/storage/databases/main/stream.py create mode 100644 synapse/storage/databases/main/tags.py create mode 100644 synapse/storage/databases/main/transactions.py create mode 100644 synapse/storage/databases/main/ui_auth.py create mode 100644 synapse/storage/databases/main/user_directory.py create mode 100644 synapse/storage/databases/main/user_erasure_store.py create mode 100644 synapse/storage/databases/state/__init__.py create mode 100644 synapse/storage/databases/state/bg_updates.py create mode 100644 synapse/storage/databases/state/schema/delta/23/drop_state_index.sql create mode 100644 synapse/storage/databases/state/schema/delta/30/state_stream.sql create mode 100644 synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql create mode 100644 synapse/storage/databases/state/schema/delta/35/add_state_index.sql create mode 100644 synapse/storage/databases/state/schema/delta/35/state.sql create mode 100644 synapse/storage/databases/state/schema/delta/35/state_dedupe.sql create mode 100644 synapse/storage/databases/state/schema/delta/47/state_group_seq.py create mode 100644 synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql create mode 100644 synapse/storage/databases/state/schema/full_schemas/54/full.sql create mode 100644 synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres create mode 100644 synapse/storage/databases/state/store.py (limited to 'tests') diff --git a/changelog.d/8033.misc b/changelog.d/8033.misc new file mode 100644 index 0000000000..7a9782d14b --- /dev/null +++ b/changelog.d/8033.misc @@ -0,0 +1 @@ +Rename storage layer objects to be more sensible. 
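The diffs that follow apply the rename mechanically: `synapse.storage.data_stores` becomes `synapse.storage.databases`, the `Database` class becomes `DatabasePool`, and the `store.db` attribute becomes `store.db_pool`. As a rough, self-contained sketch of what that means for a call site (stub classes only, not Synapse's real implementations; `ExampleStore` is a placeholder name):

class DatabasePool:  # formerly named `Database`
    def runInteraction(self, desc, func, *args, **kwargs):
        # Synapse runs `func` inside a database transaction; this stub just
        # invokes it directly with a dummy transaction object.
        return func(None, *args, **kwargs)

class ExampleStore:
    def __init__(self, database: DatabasePool):
        self.db_pool = database  # previously `self.db = database`

store = ExampleStore(DatabasePool())
# Old spelling: store.db.runInteraction("desc", fn)
# New spelling:
print(store.db_pool.runInteraction("desc", lambda txn: "ok"))
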
diff --git a/docs/user_directory.md b/docs/user_directory.md index 37dc71e751..872fc21979 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after DB corruption) get stale or out of sync. If this happens, for now the -solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) +solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql) and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 94aa8758b4..56365e2b58 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -40,7 +40,7 @@ class MockHomeserver(HomeServer): config.server_name, reactor=reactor, config=config, **kwargs ) - self.version_string = "Synapse/"+get_version_string(synapse) + self.version_string = "Synapse/" + get_version_string(synapse) if __name__ == "__main__": @@ -86,7 +86,7 @@ if __name__ == "__main__": store = hs.get_datastore() async def run_background_updates(): - await store.db.updates.run_background_updates(sleep=False) + await store.db_pool.updates.run_background_updates(sleep=False) # Stop the reactor to exit the script once every background update is run. reactor.stop() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index bee525197f..ae5e1810fc 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -35,31 +35,29 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore -from synapse.storage.data_stores.main.deviceinbox import ( - DeviceInboxBackgroundUpdateStore, -) -from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore -from synapse.storage.data_stores.main.events_bg_updates import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore +from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, ) -from synapse.storage.data_stores.main.media_repository import ( +from synapse.storage.databases.main.media_repository import ( MediaRepositoryBackgroundUpdateStore, ) -from synapse.storage.data_stores.main.registration import ( +from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) -from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore -from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore -from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore -from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore -from synapse.storage.data_stores.main.stats import StatsStore -from synapse.storage.data_stores.main.user_directory import ( +from synapse.storage.databases.main.room import RoomBackgroundUpdateStore +from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore +from 
synapse.storage.databases.main.search import SearchBackgroundUpdateStore +from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.database import Database, make_conn +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -175,14 +173,14 @@ class Store( StatsStore, ): def execute(self, f, *args, **kwargs): - return self.db.runInteraction(f.__name__, f, *args, **kwargs) + return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.db.runInteraction("execute_sql", r) + return self.db_pool.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -227,7 +225,7 @@ class Porter(object): async def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = await self.postgres_store.db.simple_select_one( + row = await self.postgres_store.db_pool.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -244,7 +242,7 @@ class Porter(object): ) = await self._setup_sent_transactions() backward_chunk = 0 else: - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -274,7 +272,7 @@ class Porter(object): await self.postgres_store.execute(delete_all) - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -318,7 +316,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. 
- await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -359,7 +357,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = await self.sqlite_store.db.runInteraction( + headers, frows, brows = await self.sqlite_store.db_pool.runInteraction( "select", r ) @@ -375,7 +373,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store.db.simple_update_one_txn( + self.postgres_store.db_pool.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -413,7 +411,7 @@ class Porter(object): return headers, rows - headers, rows = await self.sqlite_store.db.runInteraction("select", r) + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -451,7 +449,7 @@ class Porter(object): ], ) - self.postgres_store.db.simple_update_one_txn( + self.postgres_store.db_pool.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -494,7 +492,7 @@ class Porter(object): db_conn, allow_outdated_version=allow_outdated_version ) prepare_database(db_conn, engine, config=self.hs_config) - store = Store(Database(hs, db_config, engine), db_conn, hs) + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) db_conn.commit() return store @@ -502,7 +500,7 @@ class Porter(object): async def run_background_updates_on_postgres(self): # Manually apply all background updates on the PostgreSQL database. postgres_ready = ( - await self.postgres_store.db.updates.has_completed_background_updates() + await self.postgres_store.db_pool.updates.has_completed_background_updates() ) if not postgres_ready: @@ -511,9 +509,9 @@ class Porter(object): self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - await self.postgres_store.db.updates.do_next_background_update(100) + await self.postgres_store.db_pool.updates.do_next_background_update(100) postgres_ready = await ( - self.postgres_store.db.updates.has_completed_background_updates() + self.postgres_store.db_pool.updates.has_completed_background_updates() ) async def run(self): @@ -534,7 +532,7 @@ class Porter(object): # Check if all background updates are done, abort if not. updates_complete = ( - await self.sqlite_store.db.updates.has_completed_background_updates() + await self.sqlite_store.db_pool.updates.has_completed_background_updates() ) if not updates_complete: end_error = ( @@ -576,22 +574,24 @@ class Porter(object): ) try: - await self.postgres_store.db.runInteraction("alter_table", alter_table) + await self.postgres_store.db_pool.runInteraction( + "alter_table", alter_table + ) except Exception: # On Error Resume Next pass - await self.postgres_store.db.runInteraction( + await self.postgres_store.db_pool.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. 
self.progress.set_state("Fetching tables") - sqlite_tables = await self.sqlite_store.db.simple_select_onecol( + sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = await self.postgres_store.db.simple_select_onecol( + postgres_tables = await self.postgres_store.db_pool.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -692,7 +692,7 @@ class Porter(object): return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = await self.sqlite_store.db.runInteraction("select", r) + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -725,7 +725,7 @@ class Porter(object): next_chunk = await self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -794,14 +794,14 @@ class Porter(object): next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) def _setup_user_id_seq(self): def r(txn): next_id = find_max_generated_user_id_localpart(txn) + 1 txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db.runInteraction("setup_user_id_seq", r) + return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) ############################################## diff --git a/synapse/app/_base.py b/synapse/app/_base.py index fa40c68f53..2b2cd795e0 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -268,7 +268,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): # It is now safe to start your Synapse. 
hs.start_listening(listeners) - hs.get_datastore().db.start_profiling() + hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() setup_sentry(hs) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index c478df53be..1a16d0b9f8 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -125,15 +125,15 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer -from synapse.storage.data_stores.main.censor_events import CensorEventsStore -from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.data_stores.main.monthly_active_users import ( +from synapse.storage.databases.main.censor_events import CensorEventsStore +from synapse.storage.databases.main.media_repository import MediaRepositoryStore +from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) -from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.data_stores.main.search import SearchWorkerStore -from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.databases.main.presence import UserPresenceState +from synapse.storage.databases.main.search import SearchWorkerStore +from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore +from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b011e00b4b..d87a77718e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -441,7 +441,7 @@ def setup(config_options): _base.start(hs, config.listeners) - hs.get_datastore().db.updates.start_doing_background_updates() + hs.get_datastore().db_pool.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) @@ -551,8 +551,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): # # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db.engine.server_version + stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: diff --git a/synapse/config/database.py b/synapse/config/database.py index 62bccd9ef5..8a18a9ca2a 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -100,7 +100,10 @@ class DatabaseConnectionConfig: self.name = name self.config = db_config - self.data_stores = data_stores + + # The `data_stores` config is actually talking about `databases` (we + # changed the name). 
+ self.databases = data_stores class DatabaseConfig(Config): diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index cca93e3a46..afecafe15c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,7 +23,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import StateMap if TYPE_CHECKING: - from synapse.storage.data_stores.main import DataStore + from synapse.storage.databases.main import DataStore @attr.s(slots=True) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0d7d1adcea..b3764dedae 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -71,7 +71,7 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e451d6dc86..43901d0934 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -45,7 +45,7 @@ from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import ( Collection, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b3a3bb8c3f..5387b3724f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -38,7 +38,7 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler -from synapse.storage.data_stores.main import DataStore +from synapse.storage.databases.main import DataStore from synapse.storage.presence import UserPresenceState from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -319,7 +319,7 @@ class PresenceHandler(BasePresenceHandler): is some spurious presence changes that will self-correct. 
""" # If the DB pool has already terminated, don't try updating - if not self.store.db.is_running(): + if not self.store.db_pool.is_running(): return logger.info( diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a7849cefa5..8201849951 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -219,7 +219,7 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.db.runInteraction(desc, func, *args, **kwargs) + return self._store.db_pool.runInteraction(desc, func, *args, **kwargs) def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f9e2533e96..60f2e1245f 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -16,8 +16,8 @@ import logging from typing import Optional -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 525b94fd87..154f0e687c 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -17,13 +17,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore -from synapse.storage.data_stores.main.tags import TagsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data", diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index a67fbeffb7..0f8d7037bd 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.databases.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1a38f53dfb..60dd3f6701 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index a8a16dbc71..ee7f69a918 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ToDeviceStream -from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_inbox", "stream_id" diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 9d8067342f..722f3745e9 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.data_stores.main.devices import DeviceWorkerStore -from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): 
super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 8b9717c46f..1945bcf9a8 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.directory import DirectoryWorkerStore +from synapse.storage.databases.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 1a1a50a24f..da1cc836cf 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,18 +15,18 @@ # limitations under the License. import logging -from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore -from synapse.storage.data_stores.main.event_push_actions import ( +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.relations import RelationsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore -from synapse.storage.data_stores.main.stream import StreamWorkerStore -from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore -from synapse.storage.database import Database +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -55,11 +55,11 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index bcb0688954..2562b6fc38 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.data_stores.main.filtering import FilteringStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 5d210fa3a1..3291558c7a 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -16,13 +16,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream -from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index 3def367ae9..961579751c 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.keys import KeyStore +from synapse.storage.databases.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. 
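Each slaved store above follows the same constructor pattern after the rename: accept a `DatabasePool` and hand it straight to the parent class, which keeps it as `db_pool`. A minimal sketch with stand-in classes (`SlavedExampleStore` is hypothetical, not a real Synapse store):

class DatabasePool:
    """Stand-in for synapse.storage.database.DatabasePool."""

class BaseSlavedStore:
    def __init__(self, database: DatabasePool, db_conn, hs):
        self.db_pool = database  # the base class keeps the pool as `db_pool`

class SlavedExampleStore(BaseSlavedStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)
        # per-store id trackers and caches would be initialised here

store = SlavedExampleStore(DatabasePool(), db_conn=None, hs=None)
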
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 2938cb8e43..a912c04360 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,8 +15,8 @@ from synapse.replication.tcp.streams import PresenceStream from synapse.storage import DataStore -from synapse.storage.data_stores.main.presence import PresenceStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py index 28c508aad3..f85b20a071 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.data_stores.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 23ec1c5b11..590187df46 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.replication.tcp.streams import PushRulesStream -from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from .events import SlavedEventStore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index ff449f3658..63300e5da6 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,15 +15,15 @@ # limitations under the License. from synapse.replication.tcp.streams import PushersStream -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.pusher import PusherWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 6982686eb5..17ba1f22ac 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,15 +15,15 @@ # limitations under the License. 
from synapse.replication.tcp.streams import ReceiptsStream -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 4b8553e250..a40f064e2b 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 8710207ada..427c81772b 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,15 +14,15 @@ # limitations under the License. from synapse.replication.tcp.streams import PublicRoomsStream -from synapse.storage.data_stores.main.room import RoomWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.room import RoomWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index ac88e6b8c3..2091ac0df6 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.data_stores.main.transactions import TransactionStore +from synapse.storage.databases.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a8364d9793..7c292ef3f9 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -31,7 +31,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, historical_admin_path_patterns, ) -from synapse.storage.data_stores.main.room import RoomSortOrder +from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e12f65a206..f4768a9e8b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -586,7 +586,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Running url preview cache expiry") - if not (await self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db_pool.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/server.py b/synapse/server.py index 8e41112530..81d7f26f9c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -105,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStore, DataStores, Storage +from synapse.storage import Databases, DataStore, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -280,7 +280,7 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") self.start_time = int(self.get_clock().time()) - self.datastores = DataStores(self.DATASTORE_CLASS, self) + self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 25ccef5aa5..a1d3884667 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -28,7 +28,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo from synapse.types import StateMap from synapse.util import Clock diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ec89f645d4..5ef3853559 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -17,18 +17,19 @@ """ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple -databases). The `Database` class represents a single physical database. The -`data_stores` are classes that talk directly to a `Database` instance and have -associated schemas, background updates, etc. On top of those there are classes -that provide high level interfaces that combine calls to multiple `data_stores`. +databases). The `DatabasePool` class represents connections to a single physical +database. 
The `databases` are classes that talk directly to a `DatabasePool` +instance and have associated schemas, background updates, etc. On top of those +there are classes that provide high level interfaces that combine calls to +multiple `databases`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from synapse.storage.data_stores import DataStores -from synapse.storage.data_stores.main import DataStore +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage @@ -40,7 +41,7 @@ class Storage(object): """The high level interfaces for talking to various storage layers. """ - def __init__(self, hs, stores: DataStores): + def __init__(self, hs, stores: Databases): # We include the main data store here mainly so that we don't have to # rewrite all the existing code to split it into high vs low level # interfaces. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 985a042869..ca800df831 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -23,7 +23,7 @@ from canonicaljson import json from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.types import Collection, get_domain_from_id logger = logging.getLogger(__name__) @@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine - self.db = database + self.db_pool = database self.rand = random.SystemRandom() def process_replication_rows(self, stream_name, instance_name, token, rows): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 018826ef69..f43463df53 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -88,7 +88,7 @@ class BackgroundUpdater(object): def __init__(self, hs, database): self._clock = hs.get_clock() - self.db = database + self.db_pool = database # if a background update is currently running, its name. self._current_background_update = None # type: Optional[str] @@ -139,7 +139,7 @@ class BackgroundUpdater(object): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. 
- updates = await self.db.simple_select_onecol( + updates = await self.db_pool.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -160,7 +160,7 @@ class BackgroundUpdater(object): if update_name == self._current_background_update: return False - update_exists = await self.db.simple_select_one_onecol( + update_exists = await self.db_pool.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -189,10 +189,10 @@ class BackgroundUpdater(object): ORDER BY ordering, update_name """ ) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) if not self._current_background_update: - all_pending_updates = await self.db.runInteraction( + all_pending_updates = await self.db_pool.runInteraction( "background_updates", get_background_updates_txn, ) if not all_pending_updates: @@ -243,7 +243,7 @@ class BackgroundUpdater(object): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = await self.db.simple_select_one_onecol( + progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -402,7 +402,7 @@ class BackgroundUpdater(object): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.db.engine, engines.PostgresEngine): + if isinstance(self.db_pool.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None @@ -413,7 +413,7 @@ class BackgroundUpdater(object): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.db.runWithConnection(runner) + yield self.db_pool.runWithConnection(runner) yield self._end_background_update(update_name) return 1 @@ -433,7 +433,7 @@ class BackgroundUpdater(object): % update_name ) self._current_background_update = None - return self.db.simple_delete_one( + return self.db_pool.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -445,7 +445,7 @@ class BackgroundUpdater(object): progress: The progress of the update. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, @@ -463,7 +463,7 @@ class BackgroundUpdater(object): progress_json = json.dumps(progress) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py deleted file mode 100644 index 599ee470d4..0000000000 --- a/synapse/storage/data_stores/__init__.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging - -from synapse.storage.data_stores.main.events import PersistEventsStore -from synapse.storage.data_stores.state import StateGroupDataStore -from synapse.storage.database import Database, make_conn -from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database - -logger = logging.getLogger(__name__) - - -class DataStores(object): - """The various data stores. - - These are low level interfaces to physical databases. - - Attributes: - main (DataStore) - """ - - def __init__(self, main_store_class, hs): - # Note we pass in the main store class here as workers use a different main - # store. - - self.databases = [] - self.main = None - self.state = None - self.persist_events = None - - for database_config in hs.config.database.databases: - db_name = database_config.name - engine = create_engine(database_config.config) - - with make_conn(database_config, engine) as db_conn: - logger.info("Preparing database %r...", db_name) - - engine.check_database(db_conn) - prepare_database( - db_conn, engine, hs.config, data_stores=database_config.data_stores, - ) - - database = Database(hs, database_config, engine) - - if "main" in database_config.data_stores: - logger.info("Starting 'main' data store") - - # Sanity check we don't try and configure the main store on - # multiple databases. - if self.main: - raise Exception("'main' data store already configured") - - self.main = main_store_class(database, db_conn, hs) - - # If we're on a process that can persist events also - # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): - self.persist_events = PersistEventsStore( - hs, database, self.main - ) - - if "state" in database_config.data_stores: - logger.info("Starting 'state' data store") - - # Sanity check we don't try and configure the state store on - # multiple databases. - if self.state: - raise Exception("'state' data store already configured") - - self.state = StateGroupDataStore(database, db_conn, hs) - - db_conn.commit() - - self.databases.append(database) - - logger.info("Database %r prepared", db_name) - - # Sanity check that we have actually configured all the required stores. - if not self.main: - raise Exception("No 'main' data store configured") - - if not self.state: - raise Exception("No 'main' data store configured") diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py deleted file mode 100644 index 932458f651..0000000000 --- a/synapse/storage/data_stores/main/__init__.py +++ /dev/null @@ -1,592 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import calendar -import logging -import time - -from synapse.api.constants import PresenceState -from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import ( - IdGenerator, - MultiWriterIdGenerator, - StreamIdGenerator, -) -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from .account_data import AccountDataStore -from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .cache import CacheInvalidationWorkerStore -from .censor_events import CensorEventsStore -from .client_ips import ClientIpStore -from .deviceinbox import DeviceInboxStore -from .devices import DeviceStore -from .directory import DirectoryStore -from .e2e_room_keys import EndToEndRoomKeyStore -from .end_to_end_keys import EndToEndKeyStore -from .event_federation import EventFederationStore -from .event_push_actions import EventPushActionsStore -from .events_bg_updates import EventsBackgroundUpdatesStore -from .filtering import FilteringStore -from .group_server import GroupServerStore -from .keys import KeyStore -from .media_repository import MediaRepositoryStore -from .metrics import ServerMetricsStore -from .monthly_active_users import MonthlyActiveUsersStore -from .openid import OpenIdStore -from .presence import PresenceStore, UserPresenceState -from .profile import ProfileStore -from .purge_events import PurgeEventsStore -from .push_rule import PushRuleStore -from .pusher import PusherStore -from .receipts import ReceiptsStore -from .registration import RegistrationStore -from .rejections import RejectionsStore -from .relations import RelationsStore -from .room import RoomStore -from .roommember import RoomMemberStore -from .search import SearchStore -from .signatures import SignatureStore -from .state import StateStore -from .stats import StatsStore -from .stream import StreamStore -from .tags import TagsStore -from .transactions import TransactionStore -from .ui_auth import UIAuthStore -from .user_directory import UserDirectoryStore -from .user_erasure_store import UserErasureStore - -logger = logging.getLogger(__name__) - - -class DataStore( - EventsBackgroundUpdatesStore, - RoomMemberStore, - RoomStore, - RegistrationStore, - StreamStore, - ProfileStore, - PresenceStore, - TransactionStore, - DirectoryStore, - KeyStore, - StateStore, - SignatureStore, - ApplicationServiceStore, - PurgeEventsStore, - EventFederationStore, - MediaRepositoryStore, - RejectionsStore, - FilteringStore, - PusherStore, - PushRuleStore, - ApplicationServiceTransactionStore, - ReceiptsStore, - EndToEndKeyStore, - EndToEndRoomKeyStore, - SearchStore, - TagsStore, - AccountDataStore, - EventPushActionsStore, - OpenIdStore, - ClientIpStore, - DeviceStore, - DeviceInboxStore, - UserDirectoryStore, - GroupServerStore, - UserErasureStore, - MonthlyActiveUsersStore, - StatsStore, - RelationsStore, - CensorEventsStore, - UIAuthStore, - CacheInvalidationWorkerStore, - ServerMetricsStore, -): - def __init__(self, database: Database, db_conn, hs): - self.hs = hs - self._clock = hs.get_clock() - self.database_engine = database.engine - - self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" - ) - self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_inbox", "stream_id" - ) - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) - self._device_list_id_gen = 
StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ], - ) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" - ) - - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - - if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = MultiWriterIdGenerator( - db_conn, - database, - instance_name="master", - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", - sequence_name="cache_invalidation_stream_seq", - ) - else: - self._cache_id_gen = None - - super(DataStore, self).__init__(database, db_conn, hs) - - self._presence_on_startup = self._get_active_presence(db_conn) - - presence_cache_prefill, min_presence_val = self.db.get_cache_dict( - db_conn, - "presence_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._presence_id_gen.get_current_token(), - ) - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", - min_presence_val, - prefilled_cache=presence_cache_prefill, - ) - - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( - db_conn, - "device_inbox", - entity_column="user_id", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - min_device_inbox_id, - prefilled_cache=device_inbox_prefill, - ) - # The federation outbox and the local device inbox uses the same - # stream_id generator. 
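Most of this constructor is allocating stream ID generators and prefilling caches from them. A toy version of the StreamIdGenerator contract as it is used here (get_next() as a context manager handing out the next ID, get_current_token() reporting the highest ID handed out); the real class also coordinates in-flight IDs across concurrent writers, which this sketch ignores:

from contextlib import contextmanager
import threading

class ToyStreamIdGenerator:
    def __init__(self, start=0):
        self._lock = threading.Lock()
        self._current = start

    @contextmanager
    def get_next(self):
        # Hand out the next monotonically increasing stream ID.
        with self._lock:
            self._current += 1
            next_id = self._current
        yield next_id

    def get_current_token(self):
        with self._lock:
            return self._current

gen = ToyStreamIdGenerator()
with gen.get_next() as stream_id:
    pass  # persist a row tagged with stream_id inside this block
assert gen.get_current_token() == stream_id == 1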
- device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( - db_conn, - "device_federation_outbox", - entity_column="destination", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - min_device_outbox_id, - prefilled_cache=device_outbox_prefill, - ) - - device_list_max = self._device_list_id_gen.get_current_token() - self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max - ) - self._user_signature_stream_cache = StreamChangeCache( - "UserSignatureStreamChangeCache", device_list_max - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max - ) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - - def take_presence_startup_info(self): - active_on_startup = self._presence_on_startup - self._presence_on_startup = None - return active_on_startup - - def _get_active_presence(self, db_conn): - """Fetch non-offline presence from the database so that we can register - the appropriate time outs. - """ - - sql = ( - "SELECT user_id, state, last_active_ts, last_federation_update_ts," - " last_user_sync_ts, status_msg, currently_active FROM presence_stream" - " WHERE state != ?" - ) - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.db.cursor_to_dict(txn) - txn.close() - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] - - def count_daily_users(self): - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db.runInteraction("count_daily_users", self._count_users, yesterday) - - def count_monthly_users(self): - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. 
- """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - def count_r30_users(self): - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns counts globaly for a given user as well as breaking - by platform - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return self.db.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - def generate_user_daily_visits(self): - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? 
- AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - return self.db.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - - def get_users(self): - """Function to retrieve a list of users in users table. - - Args: - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.db.simple_select_list( - table="users", - keyvalues={}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "user_type", - "deactivated", - ], - desc="get_users", - ) - - def get_users_paginate( - self, start, limit, name=None, guests=True, deactivated=False - ): - """Function to retrieve a paginated list of users from - users list. This will return a json list of users and the - total number of users matching the filter criteria. - - Args: - start (int): start number to begin the query from - limit (int): number of rows to retrieve - name (string): filter for user names - guests (bool): whether to in include guest users - deactivated (bool): whether to include deactivated users - Returns: - defer.Deferred: resolves to list[dict[str, Any]], int - """ - - def get_users_paginate_txn(txn): - filters = [] - args = [] - - if name: - filters.append("name LIKE ?") - args.append("%" + name + "%") - - if not guests: - filters.append("is_guest = 0") - - if not deactivated: - filters.append("deactivated = 0") - - where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - - sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) - txn.execute(sql, args) - count = txn.fetchone()[0] - - args = [self.hs.config.server_name] + args + [limit, start] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url - FROM users as u - LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? - {} - ORDER BY u.name LIMIT ? OFFSET ? - """.format( - where_clause - ) - txn.execute(sql, args) - users = self.db.cursor_to_dict(txn) - return users, count - - return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) - - def search_users(self, term): - """Function to search users list for one or more users with - the matched term. 
- - Args: - term (str): search term - col (str): column to query term should be matched to - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.db.simple_search_list( - table="users", - term=term, - col="name", - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="search_users", - ) - - -def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): - """Called before upgrading an existing database to check that it is broadly sane - compared with the configuration. - """ - domain = config.server_name - - sql = database_engine.convert_param_style( - "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" - ) - pat = "%:" + domain - cur.execute(sql, (pat,)) - num_not_matching = cur.fetchall()[0][0] - if num_not_matching == 0: - return - - raise Exception( - "Found users in database not native to %s!\n" - "You cannot changed a synapse server_name after it's been configured" - % (domain,) - ) - - -__all__ = ["DataStore", "check_database_before_upgrade"] diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py deleted file mode 100644 index 33cc372dfd..0000000000 --- a/synapse/storage/data_stores/main/account_data.py +++ /dev/null @@ -1,430 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import logging -from typing import List, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import Database -from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -class AccountDataWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, database: Database, db_conn, hs): - account_max = self.get_max_account_data_stream_id() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max - ) - - super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) - - @abc.abstractmethod - def get_max_account_data_stream_id(self): - """Get the current max stream ID for account data stream - - Returns: - int - """ - raise NotImplementedError() - - @cached() - def get_account_data_for_user(self, user_id): - """Get all the client account_data for a user. - - Args: - user_id(str): The user to get the account_data for. 
- Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. - """ - - def get_account_data_for_user_txn(txn): - rows = self.db.simple_select_list_txn( - txn, - "account_data", - {"user_id": user_id}, - ["account_data_type", "content"], - ) - - global_account_data = { - row["account_data_type"]: db_to_json(row["content"]) for row in rows - } - - rows = self.db.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id}, - ["room_id", "account_data_type", "content"], - ) - - by_room = {} - for row in rows: - room_data = by_room.setdefault(row["room_id"], {}) - room_data[row["account_data_type"]] = db_to_json(row["content"]) - - return global_account_data, by_room - - return self.db.runInteraction( - "get_account_data_for_user", get_account_data_for_user_txn - ) - - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): - """ - Returns: - Deferred: A dict - """ - result = yield self.db.simple_select_one_onecol( - table="account_data", - keyvalues={"user_id": user_id, "account_data_type": data_type}, - retcol="content", - desc="get_global_account_data_by_type_for_user", - allow_none=True, - ) - - if result: - return db_to_json(result) - else: - return None - - @cached(num_args=2) - def get_account_data_for_room(self, user_id, room_id): - """Get all the client account_data for a user for a room. - - Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - Returns: - A deferred dict of the room account_data - """ - - def get_account_data_for_room_txn(txn): - rows = self.db.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id, "room_id": room_id}, - ["account_data_type", "content"], - ) - - return { - row["account_data_type"]: db_to_json(row["content"]) for row in rows - } - - return self.db.runInteraction( - "get_account_data_for_room", get_account_data_for_room_txn - ) - - @cached(num_args=3, max_entries=5000) - def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): - """Get the client account_data of given type for a user for a room. - - Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - account_data_type (str): The account data type to get. - Returns: - A deferred of the room account_data for that type, or None if - there isn't any set. - """ - - def get_account_data_for_room_and_type_txn(txn): - content_json = self.db.simple_select_one_onecol_txn( - txn, - table="room_account_data", - keyvalues={ - "user_id": user_id, - "room_id": room_id, - "account_data_type": account_data_type, - }, - retcol="content", - allow_none=True, - ) - - return db_to_json(content_json) if content_json else None - - return self.db.runInteraction( - "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn - ) - - async def get_updated_global_account_data( - self, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str]]: - """Get the global account_data that has changed, for the account_data stream - - Args: - last_id: the last stream_id from the previous batch. - current_id: the maximum stream_id to return up to - limit: the maximum number of rows to return - - Returns: - A list of tuples of stream_id int, user_id string, - and type string. 
- """ - if last_id == current_id: - return [] - - def get_updated_global_account_data_txn(txn): - sql = ( - "SELECT stream_id, user_id, account_data_type" - " FROM account_data WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return await self.db.runInteraction( - "get_updated_global_account_data", get_updated_global_account_data_txn - ) - - async def get_updated_room_account_data( - self, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str]]: - """Get the global account_data that has changed, for the account_data stream - - Args: - last_id: the last stream_id from the previous batch. - current_id: the maximum stream_id to return up to - limit: the maximum number of rows to return - - Returns: - A list of tuples of stream_id int, user_id string, - room_id string and type string. - """ - if last_id == current_id: - return [] - - def get_updated_room_account_data_txn(txn): - sql = ( - "SELECT stream_id, user_id, room_id, account_data_type" - " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return await self.db.runInteraction( - "get_updated_room_account_data", get_updated_room_account_data_txn - ) - - def get_updated_account_data_for_user(self, user_id, stream_id): - """Get all the client account_data for a that's changed for a user - - Args: - user_id(str): The user to get the account_data for. - stream_id(int): The point in the stream since which to get updates - Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. - """ - - def get_updated_account_data_for_user_txn(txn): - sql = ( - "SELECT account_data_type, content FROM account_data" - " WHERE user_id = ? AND stream_id > ?" - ) - - txn.execute(sql, (user_id, stream_id)) - - global_account_data = {row[0]: db_to_json(row[1]) for row in txn} - - sql = ( - "SELECT room_id, account_data_type, content FROM room_account_data" - " WHERE user_id = ? AND stream_id > ?" 
- ) - - txn.execute(sql, (user_id, stream_id)) - - account_data_by_room = {} - for row in txn: - room_account_data = account_data_by_room.setdefault(row[0], {}) - room_account_data[row[1]] = db_to_json(row[2]) - - return global_account_data, account_data_by_room - - changed = self._account_data_stream_cache.has_entity_changed( - user_id, int(stream_id) - ) - if not changed: - return defer.succeed(({}, {})) - - return self.db.runInteraction( - "get_updated_account_data_for_user", get_updated_account_data_for_user_txn - ) - - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", - ignorer_user_id, - on_invalidate=cache_context.invalidate, - ) - if not ignored_account_data: - return False - - return ignored_user_id in ignored_account_data.get("ignored_users", {}) - - -class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: Database, db_conn, hs): - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "account_data_max_stream_id", - "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - ) - - super(AccountDataStore, self).__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self): - """Get the current max stream id for the private user data stream - - Returns: - A deferred int. - """ - return self._account_data_id_gen.get_current_token() - - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): - """Add some account_data to a room for a user. - Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. - Returns: - A deferred that completes once the account_data has been added. - """ - content_json = json.dumps(content) - - with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so simple_upsert will - # retry if there is a conflict. - yield self.db.simple_upsert( - desc="add_room_account_data", - table="room_account_data", - keyvalues={ - "user_id": user_id, - "room_id": room_id, - "account_data_type": account_data_type, - }, - values={"stream_id": next_id, "content": content_json}, - lock=False, - ) - - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - yield self._update_max_stream_id(next_id) - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id)) - self.get_account_data_for_room_and_type.prefill( - (user_id, room_id, account_data_type), content - ) - - result = self._account_data_id_gen.get_current_token() - return result - - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): - """Add some account_data to a room for a user. - Args: - user_id(str): The user to add a tag for. 
- account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. - Returns: - A deferred that completes once the account_data has been added. - """ - content_json = json.dumps(content) - - with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so simple_upsert will retry if - # there is a conflict. - yield self.db.simple_upsert( - desc="add_user_account_data", - table="account_data", - keyvalues={"user_id": user_id, "account_data_type": account_data_type}, - values={"stream_id": next_id, "content": content_json}, - lock=False, - ) - - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - # - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - yield self._update_max_stream_id(next_id) - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) - self.get_global_account_data_by_type_for_user.invalidate( - (account_data_type, user_id) - ) - - result = self._account_data_id_gen.get_current_token() - return result - - def _update_max_stream_id(self, next_id): - """Update the max stream_id - - Args: - next_id(int): The the revision to advance to. - """ - - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - - def _update(txn): - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - - return self.db.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py deleted file mode 100644 index 56659fed37..0000000000 --- a/synapse/storage/data_stores/main/appservice.py +++ /dev/null @@ -1,372 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
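The two add_account_data_* methods above rely on the same trick: because the table has a unique constraint on the key columns, an upsert can safely retry on conflict and no explicit lock is needed. A standalone sqlite3 sketch of that pattern, with INSERT OR REPLACE standing in for simple_upsert:

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE account_data ("
    " user_id TEXT, account_data_type TEXT, stream_id INT, content TEXT,"
    " UNIQUE (user_id, account_data_type))"
)

def add_account_data_for_user(conn, next_id, user_id, account_data_type, content):
    # The unique constraint makes a second write for the same key replace the
    # first, which is the guarantee simple_upsert gives in the code above.
    conn.execute(
        "INSERT OR REPLACE INTO account_data"
        " (user_id, account_data_type, stream_id, content) VALUES (?, ?, ?, ?)",
        (user_id, account_data_type, next_id, json.dumps(content)),
    )

add_account_data_for_user(conn, 1, "@alice:example.com", "m.push_rules", {"global": {}})
add_account_data_for_user(conn, 2, "@alice:example.com", "m.push_rules", {"global": {"override": []}})
assert conn.execute("SELECT stream_id FROM account_data").fetchall() == [(2,)]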
-import logging -import re - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.appservice import AppServiceTransaction -from synapse.config.appservice import load_appservices -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database - -logger = logging.getLogger(__name__) - - -def _make_exclusive_regex(services_cache): - # We precompile a regex constructed from all the regexes that the AS's - # have registered for exclusive users. - exclusive_user_regexes = [ - regex.pattern - for service in services_cache - for regex in service.get_exclusive_user_regexes() - ] - if exclusive_user_regexes: - exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) - else: - # We handle this case specially otherwise the constructed regex - # will always match - exclusive_user_regex = None - - return exclusive_user_regex - - -class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - self.services_cache = load_appservices( - hs.hostname, hs.config.app_service_config_files - ) - self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - - super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) - - def get_app_services(self): - return self.services_cache - - def get_if_app_services_interested_in_user(self, user_id): - """Check if the user is one associated with an app service (exclusively) - """ - if self.exclusive_user_regex: - return bool(self.exclusive_user_regex.match(user_id)) - else: - return False - - def get_app_service_by_user_id(self, user_id): - """Retrieve an application service from their user ID. - - All application services have associated with them a particular user ID. - There is no distinguishing feature on the user ID which indicates it - represents an application service. This function allows you to map from - a user ID to an application service. - - Args: - user_id(str): The user ID to see if it is an application service. - Returns: - synapse.appservice.ApplicationService or None. - """ - for service in self.services_cache: - if service.sender == user_id: - return service - return None - - def get_app_service_by_token(self, token): - """Get the application service with the given appservice token. - - Args: - token (str): The application service token. - Returns: - synapse.appservice.ApplicationService or None. - """ - for service in self.services_cache: - if service.token == token: - return service - return None - - def get_app_service_by_id(self, as_id): - """Get the application service with the given appservice ID. - - Args: - as_id (str): The application service ID. - Returns: - synapse.appservice.ApplicationService or None. - """ - for service in self.services_cache: - if service.id == as_id: - return service - return None - - -class ApplicationServiceStore(ApplicationServiceWorkerStore): - # This is currently empty due to there not being any AS storage functions - # that can't be run on the workers. Since this may change in future, and - # to keep consistency with the other stores, we keep this empty class for - # now. - pass - - -class ApplicationServiceTransactionWorkerStore( - ApplicationServiceWorkerStore, EventsWorkerStore -): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): - """Get a list of application services based on their state. 
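_make_exclusive_regex above folds every appservice's exclusive user namespaces into one precompiled alternation so that a single match() call answers "is this user exclusively owned by some appservice?". The same idea in isolation, with made-up namespace patterns:

import re

exclusive_user_regexes = [r"@irc_.*:example\.com", r"@slack_.*:example\.com"]
if exclusive_user_regexes:
    exclusive_user_regex = re.compile(
        "|".join("(" + r + ")" for r in exclusive_user_regexes)
    )
else:
    # An empty alternation would match everything, so keep None and treat it
    # as "no appservice is interested".
    exclusive_user_regex = None

def interested_in_user(user_id):
    return bool(exclusive_user_regex and exclusive_user_regex.match(user_id))

assert interested_in_user("@irc_alice:example.com")
assert not interested_in_user("@bob:example.com")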
- - Args: - state(ApplicationServiceState): The state to filter on. - Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. - """ - results = yield self.db.simple_select_list( - "application_services_state", {"state": state}, ["as_id"] - ) - # NB: This assumes this class is linked with ApplicationServiceStore - as_list = self.get_app_services() - services = [] - - for res in results: - for service in as_list: - if service.id == res["as_id"]: - services.append(service) - return services - - @defer.inlineCallbacks - def get_appservice_state(self, service): - """Get the application service state. - - Args: - service(ApplicationService): The service whose state to set. - Returns: - A Deferred which resolves to ApplicationServiceState. - """ - result = yield self.db.simple_select_one( - "application_services_state", - {"as_id": service.id}, - ["state"], - allow_none=True, - desc="get_appservice_state", - ) - if result: - return result.get("state") - return None - - def set_appservice_state(self, service, state): - """Set the application service state. - - Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. - Returns: - A Deferred which resolves when the state was set successfully. - """ - return self.db.simple_upsert( - "application_services_state", {"as_id": service.id}, {"state": state} - ) - - def create_appservice_txn(self, service, events): - """Atomically creates a new transaction for this application service - with the given list of events. - - Args: - service(ApplicationService): The service who the transaction is for. - events(list): A list of events to put in the transaction. - Returns: - AppServiceTransaction: A new transaction. - """ - - def _create_appservice_txn(txn): - # work out new txn id (highest txn id for this service += 1) - # The highest id may be the last one sent (in which case it is last_txn) - # or it may be the highest in the txns list (which are waiting to be/are - # being sent) - last_txn_id = self._get_last_txn(txn, service.id) - - txn.execute( - "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", - (service.id,), - ) - highest_txn_id = txn.fetchone()[0] - if highest_txn_id is None: - highest_txn_id = 0 - - new_txn_id = max(highest_txn_id, last_txn_id) + 1 - - # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) - txn.execute( - "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " - "VALUES(?,?,?)", - (service.id, new_txn_id, event_ids), - ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) - - return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) - - def complete_appservice_txn(self, txn_id, service): - """Completes an application service transaction. - - Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. - """ - txn_id = int(txn_id) - - def _complete_appservice_txn(txn): - # Debugging query: Make sure the txn being completed is EXACTLY +1 from - # what was there before. If it isn't, we've got problems (e.g. the AS - # has probably missed some events), so whine loudly but still continue, - # since it shouldn't fail completion of the transaction. 
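create_appservice_txn above picks the next transaction ID as one more than both the last ID actually sent (last_txn) and the highest ID still queued, so IDs stay monotonic whether the queue is empty or the appservice is lagging. That rule as a pure function:

def next_txn_id(last_sent_txn_id, queued_txn_ids):
    # queued_txn_ids: IDs currently in application_services_txns for this AS.
    highest_queued = max(queued_txn_ids) if queued_txn_ids else 0
    return max(highest_queued, last_sent_txn_id) + 1

assert next_txn_id(last_sent_txn_id=4, queued_txn_ids=[]) == 5
assert next_txn_id(last_sent_txn_id=4, queued_txn_ids=[5, 6]) == 7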
- last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: - logger.error( - "appservice: Completing a transaction which has an ID > 1 from " - "the last ID sent to this AS. We've either dropped events or " - "sent it to the AS out of order. FIX ME. last_txn=%s " - "completing_txn=%s service_id=%s", - last_txn_id, - txn_id, - service.id, - ) - - # Set current txn_id for AS to 'txn_id' - self.db.simple_upsert_txn( - txn, - "application_services_state", - {"as_id": service.id}, - {"last_txn": txn_id}, - ) - - # Delete txn - self.db.simple_delete_txn( - txn, - "application_services_txns", - {"txn_id": txn_id, "as_id": service.id}, - ) - - return self.db.runInteraction( - "complete_appservice_txn", _complete_appservice_txn - ) - - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. - - Args: - service(ApplicationService): The app service to get the oldest txn. - Returns: - A Deferred which resolves to an AppServiceTransaction or - None. - """ - - def _get_oldest_unsent_txn(txn): - # Monotonically increasing txn ids, so just select the smallest - # one in the txns table (we delete them when they are sent) - txn.execute( - "SELECT * FROM application_services_txns WHERE as_id=?" - " ORDER BY txn_id ASC LIMIT 1", - (service.id,), - ) - rows = self.db.cursor_to_dict(txn) - if not rows: - return None - - entry = rows[0] - - return entry - - entry = yield self.db.runInteraction( - "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn - ) - - if not entry: - return None - - event_ids = db_to_json(entry["event_ids"]) - - events = yield self.get_events_as_list(event_ids) - - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) - - def _get_last_txn(self, txn, service_id): - txn.execute( - "SELECT last_txn FROM application_services_state WHERE as_id=?", - (service_id,), - ) - last_txn_id = txn.fetchone() - if last_txn_id is None or last_txn_id[0] is None: # no row exists - return 0 - else: - return int(last_txn_id[0]) # select 'last_txn' col - - def set_appservice_last_pos(self, pos): - def set_appservice_last_pos_txn(txn): - txn.execute( - "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) - ) - - return self.db.runInteraction( - "set_appservice_last_pos", set_appservice_last_pos_txn - ) - - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" - - def get_new_events_for_appservice_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id" - " FROM events AS e" - " WHERE" - " (SELECT stream_ordering FROM appservice_stream_position)" - " < e.stream_ordering" - " AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - " LIMIT ?" - ) - - txn.execute(sql, (current_id, limit)) - rows = txn.fetchall() - - upper_bound = current_id - if len(rows) == limit: - upper_bound = rows[-1][0] - - return upper_bound, [row[1] for row in rows] - - upper_bound, event_ids = yield self.db.runInteraction( - "get_new_events_for_appservice", get_new_events_for_appservice_txn - ) - - events = yield self.get_events_as_list(event_ids) - - return upper_bound, events - - -class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): - # This is currently empty due to there not being any AS storage functions - # that can't be run on the workers. Since this may change in future, and - # to keep consistency with the other stores, we keep this empty class for - # now. 
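get_new_events_for_appservice above uses a common pagination rule: fetch up to a limit of rows ordered by stream ordering, and if the page came back full, advance the upper bound only as far as the last row returned rather than all the way to current_id, since more rows may remain. The rule in isolation:

def page_upper_bound(rows, current_id, limit):
    # rows: (stream_ordering, event_id) tuples, ordered by stream_ordering.
    if len(rows) == limit:
        return rows[-1][0]
    return current_id

assert page_upper_bound([(3, "$a"), (4, "$b")], current_id=10, limit=2) == 4
assert page_upper_bound([(3, "$a")], current_id=10, limit=2) == 10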
- pass diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py deleted file mode 100644 index edc3624fed..0000000000 --- a/synapse/storage/data_stores/main/cache.py +++ /dev/null @@ -1,307 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import itertools -import logging -from typing import Any, Iterable, List, Optional, Tuple - -from synapse.api.constants import EventTypes -from synapse.replication.tcp.streams import BackfillStream, CachesStream -from synapse.replication.tcp.streams.events import ( - EventsStream, - EventsStreamCurrentStateRow, - EventsStreamEventRow, -) -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine -from synapse.util.iterutils import batch_iter - -logger = logging.getLogger(__name__) - - -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - - -class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._instance_name = hs.get_instance_name() - - async def get_all_updated_caches( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for caches replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = """ - SELECT stream_id, cache_func, keys, invalidation_ts - FROM cache_invalidation_stream_by_instance - WHERE stream_id > ? AND instance_name = ? - ORDER BY stream_id ASC - LIMIT ? 
- """ - txn.execute(sql, (last_id, instance_name, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == EventsStream.NAME: - for row in rows: - self._process_event_stream_row(token, row) - elif stream_name == BackfillStream.NAME: - for row in rows: - self._invalidate_caches_for_event( - -token, - row.event_id, - row.room_id, - row.type, - row.state_key, - row.redacts, - row.relates_to, - backfilled=True, - ) - elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - - for row in rows: - if row.cache_func == CURRENT_STATE_CACHE_NAME: - if row.keys is None: - raise Exception( - "Can't send an 'invalidate all' for current state cache" - ) - - room_id = row.keys[0] - members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) - else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) - - super().process_replication_rows(stream_name, instance_name, token, rows) - - def _process_event_stream_row(self, token, row): - data = row.data - - if row.type == EventsStreamEventRow.TypeId: - self._invalidate_caches_for_event( - token, - data.event_id, - data.room_id, - data.type, - data.state_key, - data.redacts, - data.relates_to, - backfilled=False, - ) - elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - - if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key,) - ) - else: - raise Exception("Unknown events stream row type %s" % (row.type,)) - - def _invalidate_caches_for_event( - self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): - self._invalidate_get_event_cache(event_id) - - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_message_count_for_user.invalidate_many((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) - - if not backfilled: - self._events_stream_cache.entity_has_changed(room_id, stream_ordering) - - if redacts: - self._invalidate_get_event_cache(redacts) - - if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) - - if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) - - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. 
- """ - cache_func = getattr(self, cache_name, None) - if not cache_func: - return - - cache_func.invalidate(keys) - await self.db.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, - cache_func.__name__, - keys, - ) - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_all_cache_and_stream(self, txn, cache_func): - """Invalidates the entire cache and adds it to the cache stream so slaves - will know to invalidate their caches. - """ - - txn.call_after(cache_func.invalidate_all) - self._send_invalidation_to_replication(txn, cache_func.__name__, None) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. - self._send_invalidation_to_replication( - txn, CURRENT_STATE_CACHE_NAME, [room_id] - ) - - def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name - keys: Entry to invalidate. If None will invalidate all. - """ - - if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: - raise Exception( - "Can't stream invalidate all with magic current state cache" - ) - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. 
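_invalidate_state_caches_and_stream above keeps replication lines short by chunking large member lists (50 members per chunk, each chunk prefixed with the room ID). The same chunking, with a plain generator standing in for Synapse's batch_iter helper:

import itertools

def batch_iter(iterable, size):
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk

def invalidation_keys(room_id, members_changed, chunk_size=50):
    if not members_changed:
        # No members changed: still invalidate the other room-scoped caches.
        yield [room_id]
        return
    for chunk in batch_iter(members_changed, chunk_size):
        yield list(itertools.chain([room_id], chunk))

keys = list(invalidation_keys("!room:example.com", ["@a:x", "@b:x", "@c:x"], chunk_size=2))
assert keys == [["!room:example.com", "@a:x", "@b:x"], ["!room:example.com", "@c:x"]]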
- stream_id = self._cache_id_gen.get_next_txn(txn) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - if keys is not None: - keys = list(keys) - - self.db.simple_insert_txn( - txn, - table="cache_invalidation_stream_by_instance", - values={ - "stream_id": stream_id, - "instance_name": self._instance_name, - "cache_func": cache_name, - "keys": keys, - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_cache_stream_token(self, instance_name): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token(instance_name) - else: - return 0 diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py deleted file mode 100644 index 2d48261724..0000000000 --- a/synapse/storage/data_stores/main/censor_events.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING - -from twisted.internet import defer - -from synapse.events.utils import prune_event_dict -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.data_stores.main.events import encode_json -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -logger = logging.getLogger(__name__) - - -class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): - def __init__(self, database: Database, db_conn, hs: "HomeServer"): - super().__init__(database, db_conn, hs) - - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) - - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) - - async def _censor_redactions(self): - """Censors all redactions older than the configured period that haven't - been censored yet. - - By censor we mean update the event_json table with the redacted event. - """ - - if self.hs.config.redaction_retention_period is None: - return - - if not ( - await self.db.updates.has_completed_background_update( - "redactions_have_censored_ts_idx" - ) - ): - # We don't want to run this until the appropriate index has been - # created. - return - - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period - - # We fetch all redactions that: - # 1. point to an event we have, - # 2. has a received_ts from before the cut off, and - # 3. we haven't yet censored. - # - # This is limited to 100 events to ensure that we don't try and do too - # much at once. We'll get called again so this should eventually catch - # up. 
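Both ends of the cache invalidation stream above identify caches by the name of the cached method, so a worker can turn a replicated row back into a local invalidation with getattr() and simply skip names it does not know. A sketch of the replay side, with a stand-in cache object instead of Synapse's cached descriptors:

class FakeCache:
    def __init__(self):
        self.invalidated = []

    def invalidate(self, keys):
        self.invalidated.append(tuple(keys))

class WorkerStore:
    def __init__(self):
        self.get_account_data_for_user = FakeCache()

    def process_invalidation_row(self, cache_func_name, keys):
        cache_func = getattr(self, cache_func_name, None)
        if cache_func is None:
            return  # this worker doesn't have that cache; nothing to do
        cache_func.invalidate(keys)

store = WorkerStore()
store.process_invalidation_row("get_account_data_for_user", ("@alice:example.com",))
assert store.get_account_data_for_user.invalidated == [("@alice:example.com",)]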
- sql = """ - SELECT redactions.event_id, redacts FROM redactions - LEFT JOIN events AS original_event ON ( - redacts = original_event.event_id - ) - WHERE NOT have_censored - AND redactions.received_ts <= ? - ORDER BY redactions.received_ts ASC - LIMIT ? - """ - - rows = await self.db.execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 - ) - - updates = [] - - for redaction_id, event_id in rows: - redaction_event = await self.get_event(redaction_id, allow_none=True) - original_event = await self.get_event( - event_id, allow_rejected=True, allow_none=True - ) - - # The SQL above ensures that we have both the redaction and - # original event, so if the `get_event` calls return None it - # means that the redaction wasn't allowed. Either way we know that - # the result won't change so we mark the fact that we've checked. - if ( - redaction_event - and original_event - and original_event.internal_metadata.is_redacted() - ): - # Redaction was allowed - pruned_json = encode_json( - prune_event_dict( - original_event.room_version, original_event.get_dict() - ) - ) - else: - # Redaction wasn't allowed - pruned_json = None - - updates.append((redaction_id, event_id, pruned_json)) - - def _update_censor_txn(txn): - for redaction_id, event_id, pruned_json in updates: - if pruned_json: - self._censor_event_txn(txn, event_id, pruned_json) - - self.db.simple_update_one_txn( - txn, - table="redactions", - keyvalues={"event_id": redaction_id}, - updatevalues={"have_censored": True}, - ) - - await self.db.runInteraction("_update_censor_txn", _update_censor_txn) - - def _censor_event_txn(self, txn, event_id, pruned_json): - """Censor an event by replacing its JSON in the event_json table with the - provided pruned JSON. - - Args: - txn (LoggingTransaction): The database transaction. - event_id (str): The ID of the event to censor. - pruned_json (str): The pruned JSON - """ - self.db.simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) - - @defer.inlineCallbacks - def expire_event(self, event_id): - """Retrieve and expire an event that has expired, and delete its associated - expiry timestamp. If the event can't be retrieved, delete its associated - timestamp so we don't try to expire it again in the future. - - Args: - event_id (str): The ID of the event to delete. - """ - # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) - - def delete_expired_event_txn(txn): - # Delete the expiry timestamp associated with this event from the database. - self._delete_event_expiry_txn(txn, event_id) - - if not event: - # If we can't find the event, log a warning and delete the expiry date - # from the database so that we don't try to expire it again in the - # future. - logger.warning( - "Can't expire event %s because we don't have it.", event_id - ) - return - - # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( - prune_event_dict(event.room_version, event.get_dict()) - ) - - # Update the event_json table to replace the event's JSON with the pruned - # JSON. - self._censor_event_txn(txn, event.event_id, pruned_json) - - # We need to invalidate the event cache entry for this event because we - # changed its content in the database. We can't call - # self._invalidate_cache_and_stream because self.get_event_cache isn't of the - # right type. 
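The loop above only replaces an event's JSON when it still has both the redaction and the original event and the original was genuinely redacted; in every other case it just records that the redaction has been checked. That decision as a pure function, with plain dicts standing in for Synapse's event objects (the real check is original_event.internal_metadata.is_redacted()):

def censor_decision(redaction_event, original_event):
    if (
        redaction_event is not None
        and original_event is not None
        and original_event.get("redacted", False)
    ):
        return "replace-with-pruned-json"
    return "mark-checked-only"

assert censor_decision({"type": "m.room.redaction"}, {"redacted": True}) == "replace-with-pruned-json"
assert censor_decision(None, {"redacted": True}) == "mark-checked-only"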
- txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) - # Send that invalidation to replication so that other workers also invalidate - # the event cache. - self._send_invalidation_to_replication( - txn, "_get_event_cache", (event.event_id,) - ) - - yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) - - def _delete_event_expiry_txn(self, txn, event_id): - """Delete the expiry timestamp associated with an event ID without deleting the - actual event. - - Args: - txn (LoggingTransaction): The transaction to use to perform the deletion. - event_id (str): The event ID to delete the associated expiry timestamp of. - """ - return self.db.simple_delete_txn( - txn=txn, table="event_expiry", keyvalues={"event_id": event_id} - ) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py deleted file mode 100644 index 995d4764a9..0000000000 --- a/synapse/storage/data_stores/main/client_ips.py +++ /dev/null @@ -1,576 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from twisted.internet import defer - -from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache - -logger = logging.getLogger(__name__) - -# Number of msec of granularity to store the user IP 'last seen' time. 
Smaller -# times give more inserts into the database even for readonly API hits -# 120 seconds == 2 minutes -LAST_SEEN_GRANULARITY = 120 * 1000 - - -class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.db.updates.register_background_index_update( - "user_ips_device_index", - index_name="user_ips_device_id", - table="user_ips", - columns=["user_id", "device_id", "last_seen"], - ) - - self.db.updates.register_background_index_update( - "user_ips_last_seen_index", - index_name="user_ips_last_seen", - table="user_ips", - columns=["user_id", "last_seen"], - ) - - self.db.updates.register_background_index_update( - "user_ips_last_seen_only_index", - index_name="user_ips_last_seen_only", - table="user_ips", - columns=["last_seen"], - ) - - self.db.updates.register_background_update_handler( - "user_ips_analyze", self._analyze_user_ip - ) - - self.db.updates.register_background_update_handler( - "user_ips_remove_dupes", self._remove_user_ip_dupes - ) - - # Register a unique index - self.db.updates.register_background_index_update( - "user_ips_device_unique_index", - index_name="user_ips_user_token_ip_unique_index", - table="user_ips", - columns=["user_id", "access_token", "ip"], - unique=True, - ) - - # Drop the old non-unique index - self.db.updates.register_background_update_handler( - "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique - ) - - # Update the last seen info in devices. - self.db.updates.register_background_update_handler( - "devices_last_seen", self._devices_last_seen_update - ) - - @defer.inlineCallbacks - def _remove_user_ip_nonunique(self, progress, batch_size): - def f(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") - txn.close() - - yield self.db.runWithConnection(f) - yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") - return 1 - - @defer.inlineCallbacks - def _analyze_user_ip(self, progress, batch_size): - # Background update to analyze user_ips table before we run the - # deduplication background update. The table may not have been analyzed - # for ages due to the table locks. - # - # This will lock out the naive upserts to user_ips while it happens, but - # the analyze should be quick (28GB table takes ~10s) - def user_ips_analyze(txn): - txn.execute("ANALYZE user_ips") - - yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - - yield self.db.updates._end_background_update("user_ips_analyze") - - return 1 - - @defer.inlineCallbacks - def _remove_user_ip_dupes(self, progress, batch_size): - # This works function works by scanning the user_ips table in batches - # based on `last_seen`. For each row in a batch it searches the rest of - # the table to see if there are any duplicates, if there are then they - # are removed and replaced with a suitable row. - - # Fetch the start of the batch - begin_last_seen = progress.get("last_seen", 0) - - def get_last_seen(txn): - txn.execute( - """ - SELECT last_seen FROM user_ips - WHERE last_seen > ? - ORDER BY last_seen - LIMIT 1 - OFFSET ? 
- """, - (begin_last_seen, batch_size), - ) - row = txn.fetchone() - if row: - return row[0] - else: - return None - - # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.db.runInteraction( - "user_ips_dups_get_last_seen", get_last_seen - ) - - # If it returns None, then we're processing the last batch - last = end_last_seen is None - - logger.info( - "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", - begin_last_seen, - end_last_seen, - ) - - def remove(txn): - # This works by looking at all entries in the given time span, and - # then for each (user_id, access_token, ip) tuple in that range - # checking for any duplicates in the rest of the table (via a join). - # It then only returns entries which have duplicates, and the max - # last_seen across all duplicates, which can the be used to delete - # all other duplicates. - # It is efficient due to the existence of (user_id, access_token, - # ip) and (last_seen) indices. - - # Define the search space, which requires handling the last batch in - # a different way - if last: - clause = "? <= last_seen" - args = (begin_last_seen,) - else: - clause = "? <= last_seen AND last_seen < ?" - args = (begin_last_seen, end_last_seen) - - # (Note: The DISTINCT in the inner query is important to ensure that - # the COUNT(*) is accurate, otherwise double counting may happen due - # to the join effectively being a cross product) - txn.execute( - """ - SELECT user_id, access_token, ip, - MAX(device_id), MAX(user_agent), MAX(last_seen), - COUNT(*) - FROM ( - SELECT DISTINCT user_id, access_token, ip - FROM user_ips - WHERE {} - ) c - INNER JOIN user_ips USING (user_id, access_token, ip) - GROUP BY user_id, access_token, ip - HAVING count(*) > 1 - """.format( - clause - ), - args, - ) - res = txn.fetchall() - - # We've got some duplicates - for i in res: - user_id, access_token, ip, device_id, user_agent, last_seen, count = i - - # We want to delete the duplicates so we end up with only a - # single row. - # - # The naive way of doing this would be just to delete all rows - # and reinsert a constructed row. However, if there are a lot of - # duplicate rows this can cause the table to grow a lot, which - # can be problematic in two ways: - # 1. If user_ips is already large then this can cause the - # table to rapidly grow, potentially filling the disk. - # 2. Reinserting a lot of rows can confuse the table - # statistics for postgres, causing it to not use the - # correct indices for the query above, resulting in a full - # table scan. This is incredibly slow for large tables and - # can kill database performance. (This seems to mainly - # happen for the last query where the clause is simply `? < - # last_seen`) - # - # So instead we want to delete all but *one* of the duplicate - # rows. That is hard to do reliably, so we cheat and do a two - # step process: - # 1. Delete all rows with a last_seen strictly less than the - # max last_seen. This hopefully results in deleting all but - # one row the majority of the time, but there may be - # duplicate last_seen - # 2. If multiple rows remain, we fall back to the naive method - # and simply delete all rows and reinsert. - # - # Note that this relies on no new duplicate rows being inserted, - # but if that is happening then this entire process is futile - # anyway. - - # Do step 1: - - txn.execute( - """ - DELETE FROM user_ips - WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ? 
- """, - (user_id, access_token, ip, last_seen), - ) - if txn.rowcount == count - 1: - # We deleted all but one of the duplicate rows, i.e. there - # is exactly one remaining and so there is nothing left to - # do. - continue - elif txn.rowcount >= count: - raise Exception( - "We deleted more duplicate rows from 'user_ips' than expected" - ) - - # The previous step didn't delete enough rows, so we fallback to - # step 2: - - # Drop all the duplicates - txn.execute( - """ - DELETE FROM user_ips - WHERE user_id = ? AND access_token = ? AND ip = ? - """, - (user_id, access_token, ip), - ) - - # Add in one to be the last_seen - txn.execute( - """ - INSERT INTO user_ips - (user_id, access_token, ip, device_id, user_agent, last_seen) - VALUES (?, ?, ?, ?, ?, ?) - """, - (user_id, access_token, ip, device_id, user_agent, last_seen), - ) - - self.db.updates._background_update_progress_txn( - txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} - ) - - yield self.db.runInteraction("user_ips_dups_remove", remove) - - if last: - yield self.db.updates._end_background_update("user_ips_remove_dupes") - - return batch_size - - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): - """Background update to insert last seen info into devices table - """ - - last_user_id = progress.get("last_user_id", "") - last_device_id = progress.get("last_device_id", "") - - def _devices_last_seen_update_txn(txn): - # This consists of two queries: - # - # 1. The sub-query searches for the next N devices and joins - # against user_ips to find the max last_seen associated with - # that device. - # 2. The outer query then joins again against user_ips on - # user/device/last_seen. This *should* hopefully only - # return one row, but if it does return more than one then - # we'll just end up updating the same device row multiple - # times, which is fine. - - where_clause, where_args = make_tuple_comparison_clause( - self.database_engine, - [("user_id", last_user_id), ("device_id", last_device_id)], - ) - - sql = """ - SELECT - last_seen, ip, user_agent, user_id, device_id - FROM ( - SELECT - user_id, device_id, MAX(u.last_seen) AS last_seen - FROM devices - INNER JOIN user_ips AS u USING (user_id, device_id) - WHERE %(where_clause)s - GROUP BY user_id, device_id - ORDER BY user_id ASC, device_id ASC - LIMIT ? - ) c - INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) - """ % { - "where_clause": where_clause - } - txn.execute(sql, where_args + [batch_size]) - - rows = txn.fetchall() - if not rows: - return 0 - - sql = """ - UPDATE devices - SET last_seen = ?, ip = ?, user_agent = ? - WHERE user_id = ? AND device_id = ? 
- """ - txn.execute_batch(sql, rows) - - _, _, _, user_id, device_id = rows[-1] - self.db.updates._background_update_progress_txn( - txn, - "devices_last_seen", - {"last_user_id": user_id, "last_device_id": device_id}, - ) - - return len(rows) - - updated = yield self.db.runInteraction( - "_devices_last_seen_update", _devices_last_seen_update_txn - ) - - if not updated: - yield self.db.updates._end_background_update("devices_last_seen") - - return updated - - -class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): - - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 - ) - - super(ClientIpStore, self).__init__(database, db_conn, hs) - - self.user_ips_max_age = hs.config.user_ips_max_age - - # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} - - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_reactor().addSystemEventTrigger( - "before", "shutdown", self._update_client_ips_batch - ) - - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - - @defer.inlineCallbacks - def insert_client_ip( - self, user_id, access_token, ip, user_agent, device_id, now=None - ): - if not now: - now = int(self._clock.time_msec()) - key = (user_id, access_token, ip) - - try: - last_seen = self.client_ip_last_seen.get(key) - except KeyError: - last_seen = None - yield self.populate_monthly_active_users(user_id) - # Rate-limited inserts - if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - return - - self.client_ip_last_seen.prefill(key, now) - - self._batch_row_update[key] = (user_agent, device_id, now) - - @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): - - # If the DB pool has already terminated, don't try updating - if not self.db.is_running(): - return - - to_update = self._batch_row_update - self._batch_row_update = {} - - return self.db.runInteraction( - "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update - ) - - def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self.db._unsafe_to_upsert_tables or ( - not self.database_engine.can_native_upsert - ): - self.database_engine.lock_table(txn, "user_ips") - - for entry in to_update.items(): - (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - - try: - self.db.simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - }, - values={ - "user_agent": user_agent, - "device_id": device_id, - "last_seen": last_seen, - }, - lock=False, - ) - - # Technically an access token might not be associated with - # a device so we need to check. - if device_id: - # this is always an update rather than an upsert: the row should - # already exist, and if it doesn't, that may be because it has been - # deleted, and we don't want to re-create it. 
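Editorial note: the client-IP store removed above rate-limits writes with LAST_SEEN_GRANULARITY and buffers them in memory, flushing the batch periodically in a single transaction. A simplified, synchronous sketch of that pattern (no Twisted, no caches beyond plain dicts); the class and method names are illustrative only.

import time

LAST_SEEN_GRANULARITY = 120 * 1000  # only persist "last seen" every two minutes

class ClientIpBatcher:
    def __init__(self):
        self._last_seen = {}   # (user_id, access_token, ip) -> ms timestamp
        self._pending = {}     # same key -> (user_agent, device_id, last_seen)

    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
        now = int(time.time() * 1000)
        key = (user_id, access_token, ip)
        last = self._last_seen.get(key)
        if last is not None and now - last < LAST_SEEN_GRANULARITY:
            return  # seen recently enough: skip the write entirely
        self._last_seen[key] = now
        self._pending[key] = (user_agent, device_id, now)

    def flush(self, upsert):
        # Called on a timer; `upsert` persists one row (e.g. an UPSERT into user_ips).
        to_update, self._pending = self._pending, {}
        for (user_id, access_token, ip), (ua, device_id, last_seen) in to_update.items():
            upsert(user_id, access_token, ip, ua, device_id, last_seen)

batcher = ClientIpBatcher()
batcher.insert_client_ip("@u:example.com", "token", "10.0.0.1", "curl/8", "DEV")
batcher.flush(lambda *row: print(row))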
- self.db.simple_update_txn( - txn, - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - updatevalues={ - "user_agent": user_agent, - "last_seen": last_seen, - "ip": ip, - }, - ) - except Exception as e: - # Failed to upsert, log and continue - logger.error("Failed to insert client IP %r: %r", entry, e) - - @defer.inlineCallbacks - def get_last_client_ip_by_device(self, user_id, device_id): - """For each device_id listed, give the user_ip it was last seen on - - Args: - user_id (str) - device_id (str): If None fetches all devices for the user - - Returns: - defer.Deferred: resolves to a dict, where the keys - are (user_id, device_id) tuples. The values are also dicts, with - keys giving the column names - """ - - keyvalues = {"user_id": user_id} - if device_id is not None: - keyvalues["device_id"] = device_id - - res = yield self.db.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), - ) - - ret = {(d["user_id"], d["device_id"]): d for d in res} - for key in self._batch_row_update: - uid, access_token, ip = key - if uid == user_id: - user_agent, did, last_seen = self._batch_row_update[key] - if not device_id or did == device_id: - ret[(user_id, device_id)] = { - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": did, - "last_seen": last_seen, - } - return ret - - @defer.inlineCallbacks - def get_user_ip_and_agents(self, user): - user_id = user.to_string() - results = {} - - for key in self._batch_row_update: - uid, access_token, ip, = key - if uid == user_id: - user_agent, _, last_seen = self._batch_row_update[key] - results[(access_token, ip)] = (user_agent, last_seen) - - rows = yield self.db.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "last_seen"], - desc="get_user_ip_and_agents", - ) - - results.update( - ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) - for row in rows - ) - return [ - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in results.items() - ] - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.db.updates.has_completed_background_update( - "devices_last_seen" - ): - # Only start pruning if we have finished populating the devices - # last seen info. - return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? 
- ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py deleted file mode 100644 index da297b31fb..0000000000 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ /dev/null @@ -1,476 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import List, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database -from synapse.util.caches.expiringcache import ExpiringCache - -logger = logging.getLogger(__name__) - - -class DeviceInboxWorkerStore(SQLBaseStore): - def get_to_device_stream_token(self): - return self._device_inbox_id_gen.get_current_token() - - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device - message stream. - Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. - """ - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_stream_id - ) - if not has_changed: - return defer.succeed(([], current_stream_id)) - - def get_new_messages_for_device_txn(txn): - sql = ( - "SELECT stream_id, message_json FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute( - sql, (user_id, device_id, last_stream_id, current_stream_id, limit) - ) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(db_to_json(row[1])) - if len(messages) < limit: - stream_pos = current_stream_id - return messages, stream_pos - - return self.db.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn - ) - - @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves to the number of messages deleted. 
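Editorial note: the pruning query removed above bounds each pass by deleting only rows at or below the greatest last_seen among the oldest N candidates, so a huge backlog is chewed through gradually. A minimal sqlite3 illustration of the same query shape; the table is trimmed to the relevant columns and the batch size of 3 is arbitrary.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen INTEGER)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?)",
    [("@u:example.com", ts) for ts in range(10)],
)

cutoff = 7       # everything with last_seen <= 7 is expired
batch_limit = 3  # but delete at most ~3 rows per pass

conn.execute(
    """
    DELETE FROM user_ips
    WHERE last_seen <= (
        SELECT COALESCE(MAX(last_seen), -1)
        FROM (
            SELECT last_seen FROM user_ips
            WHERE last_seen <= ?
            ORDER BY last_seen ASC
            LIMIT ?
        ) AS u
    )
    """,
    (cutoff, batch_limit),
)
print(conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0])  # 7 rows remain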
- """ - # If we have cached the last stream id we've deleted up to, we can - # check if there is likely to be anything that needs deleting - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), None - ) - - set_tag("last_deleted_stream_id", last_deleted_stream_id) - - if last_deleted_stream_id: - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_deleted_stream_id - ) - if not has_changed: - log_kv({"message": "No changes in cache since last check"}) - return 0 - - def delete_messages_for_device_txn(txn): - sql = ( - "DELETE FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (user_id, device_id, up_to_stream_id)) - return txn.rowcount - - count = yield self.db.runInteraction( - "delete_messages_for_device", delete_messages_for_device_txn - ) - - log_kv( - {"message": "deleted {} messages for device".format(count), "count": count} - ) - - # Update the cache, ensuring that we only ever increase the value - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), 0 - ) - self._last_device_delete_cache[(user_id, device_id)] = max( - last_deleted_stream_id, up_to_stream_id - ) - - return count - - @trace - def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit - ): - """ - Args: - destination(str): The name of the remote server. - last_stream_id(int|long): The last position of the device message stream - that the server sent up to. - current_stream_id(int|long): The current position of the device - message stream. - Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. - """ - - set_tag("destination", destination) - set_tag("last_stream_id", last_stream_id) - set_tag("current_stream_id", current_stream_id) - set_tag("limit", limit) - - has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( - destination, last_stream_id - ) - if not has_changed or last_stream_id == current_stream_id: - log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) - - if limit <= 0: - # This can happen if we run out of room for EDUs in the transaction. - return defer.succeed(([], last_stream_id)) - - @trace - def get_new_messages_for_remote_destination_txn(txn): - sql = ( - "SELECT stream_id, messages_json FROM device_federation_outbox" - " WHERE destination = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(db_to_json(row[1])) - if len(messages) < limit: - log_kv({"message": "Set stream position to current position"}) - stream_pos = current_stream_id - return messages, stream_pos - - return self.db.runInteraction( - "get_new_device_msgs_for_remote", - get_new_messages_for_remote_destination_txn, - ) - - @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): - """Used to delete messages when the remote destination acknowledges - their receipt. - - Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. - """ - - def delete_messages_for_remote_destination_txn(txn): - sql = ( - "DELETE FROM device_federation_outbox" - " WHERE destination = ?" - " AND stream_id <= ?" 
- ) - txn.execute(sql, (destination, up_to_stream_id)) - - return self.db.runInteraction( - "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn - ) - - async def get_all_new_device_messages( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for to device replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_id, last_id + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_id, upper_pos)) - updates = [(row[0], row[1:]) for row in txn] - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_id, upper_pos)) - updates.extend((row[0], row[1:]) for row in txn) - - # Order by ascending stream ordering - updates.sort() - - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) - - -class DeviceInboxBackgroundUpdateStore(SQLBaseStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, database: Database, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.db.updates.register_background_index_update( - "device_inbox_stream_index", - index_name="device_inbox_stream_id_user_id", - table="device_inbox", - columns=["stream_id", "user_id"], - ) - - self.db.updates.register_background_update_handler( - self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox - ) - - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") - txn.close() - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) - - return 1 - - -class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, database: Database, db_conn, hs): - super(DeviceInboxStore, self).__init__(database, db_conn, hs) - - # Map of (user_id, device_id) to the last stream_id that has been - # deleted up to. This is so that we can no op deletions. 
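Editorial note: one detail of the device-inbox code removed above is that deletions are guarded by a small cache of the highest stream id already deleted per (user, device), so repeated acknowledgements become no-ops. A stripped-down sketch of that guard, using a plain dict instead of Synapse's ExpiringCache and a callback in place of the real DELETE.

last_deleted = {}  # (user_id, device_id) -> highest stream_id already deleted

def delete_messages_for_device(user_id, device_id, up_to_stream_id, run_delete):
    key = (user_id, device_id)
    if last_deleted.get(key, -1) >= up_to_stream_id:
        # Nothing new could possibly match: skip the database round-trip.
        return 0
    count = run_delete(user_id, device_id, up_to_stream_id)
    # Only ever ratchet the cached value upwards.
    last_deleted[key] = max(last_deleted.get(key, -1), up_to_stream_id)
    return count

# The second call with the same bound does no work.
print(delete_messages_for_device("@u:ex", "DEV", 40, lambda *a: 3))  # 3
print(delete_messages_for_device("@u:ex", "DEV", 40, lambda *a: 3))  # 0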
- self._last_device_delete_cache = ExpiringCache( - cache_name="last_device_delete_cache", - clock=self._clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, - ) - - @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): - """Used to send messages from this server. - - Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): - Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): - Dictionary of destination server_name to the EDU JSON to send. - Returns: - A deferred stream_id that resolves when the messages have been - inserted. - """ - - def add_messages_txn(txn, now_ms, stream_id): - # Add the local messages directly to the local inbox. - self._add_messages_to_local_device_inbox_txn( - txn, stream_id, local_messages_by_user_then_device - ) - - # Add the remote messages to the federation outbox. - # We'll send them to a remote server when we next send a - # federation transaction to that destination. - sql = ( - "INSERT INTO device_federation_outbox" - " (destination, stream_id, queued_ts, messages_json)" - " VALUES (?,?,?,?)" - ) - rows = [] - for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) - rows.append((destination, stream_id, now_ms, edu_json)) - txn.executemany(sql, rows) - - with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() - yield self.db.runInteraction( - "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id - ) - for user_id in local_messages_by_user_then_device.keys(): - self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) - for destination in remote_messages_by_destination.keys(): - self._device_federation_outbox_stream_cache.entity_has_changed( - destination, stream_id - ) - - return self._device_inbox_id_gen.get_current_token() - - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): - def add_messages_txn(txn, now_ms, stream_id): - # Check if we've already inserted a matching message_id for that - # origin. This can happen if the origin doesn't receive our - # acknowledgement from the first time we received the message. - already_inserted = self.db.simple_select_one_txn( - txn, - table="device_federation_inbox", - keyvalues={"origin": origin, "message_id": message_id}, - retcols=("message_id",), - allow_none=True, - ) - if already_inserted is not None: - return - - # Add an entry for this message_id so that we know we've processed - # it. - self.db.simple_insert_txn( - txn, - table="device_federation_inbox", - values={ - "origin": origin, - "message_id": message_id, - "received_ts": now_ms, - }, - ) - - # Add the messages to the approriate local device inboxes so that - # they'll be sent to the devices when they next sync. 
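Editorial note: the remote path above guards against re-delivery by recording each (origin, message_id) pair; if a federation transaction is retried, the pair already exists and the messages are not inserted twice. A small sqlite3 sketch of that idempotency check, with a callback standing in for the local fan-out.

import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE device_federation_inbox"
    " (origin TEXT, message_id TEXT, received_ts INTEGER,"
    "  PRIMARY KEY (origin, message_id))"
)

def add_from_remote(origin, message_id, deliver):
    already = conn.execute(
        "SELECT 1 FROM device_federation_inbox WHERE origin = ? AND message_id = ?",
        (origin, message_id),
    ).fetchone()
    if already is not None:
        return False  # duplicate transaction: drop silently
    conn.execute(
        "INSERT INTO device_federation_inbox VALUES (?, ?, ?)",
        (origin, message_id, int(time.time() * 1000)),
    )
    deliver()  # fan the messages out to local device inboxes
    return True

print(add_from_remote("other.example", "m1", lambda: None))  # True
print(add_from_remote("other.example", "m1", lambda: None))  # False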
- self._add_messages_to_local_device_inbox_txn( - txn, stream_id, local_messages_by_user_then_device - ) - - with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() - yield self.db.runInteraction( - "add_messages_from_remote_to_device_inbox", - add_messages_txn, - now_ms, - stream_id, - ) - for user_id in local_messages_by_user_then_device.keys(): - self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) - - return stream_id - - def _add_messages_to_local_device_inbox_txn( - self, txn, stream_id, messages_by_user_then_device - ): - local_by_user_then_device = {} - for user_id, messages_by_device in messages_by_user_then_device.items(): - messages_json_for_user = {} - devices = list(messages_by_device.keys()) - if len(devices) == 1 and devices[0] == "*": - # Handle wildcard device_ids. - sql = "SELECT device_id FROM devices WHERE user_id = ?" - txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) - for row in txn: - # Add the message for all devices for this user on this - # server. - device = row[0] - messages_json_for_user[device] = message_json - else: - if not devices: - continue - - clause, args = make_in_list_sql_clause( - txn.database_engine, "device_id", devices - ) - sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause - - # TODO: Maybe this needs to be done in batches if there are - # too many local devices for a given user. - txn.execute(sql, [user_id] + list(args)) - for row in txn: - # Only insert into the local inbox if the device exists on - # this server - device = row[0] - message_json = json.dumps(messages_by_device[device]) - messages_json_for_user[device] = message_json - - if messages_json_for_user: - local_by_user_then_device[user_id] = messages_json_for_user - - if not local_by_user_then_device: - return - - sql = ( - "INSERT INTO device_inbox" - " (user_id, device_id, stream_id, message_json)" - " VALUES (?,?,?,?)" - ) - rows = [] - for user_id, messages_by_device in local_by_user_then_device.items(): - for device_id, message_json in messages_by_device.items(): - rows.append((user_id, device_id, stream_id, message_json)) - - txn.executemany(sql, rows) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py deleted file mode 100644 index 45581a6500..0000000000 --- a/synapse/storage/data_stores/main/devices.py +++ /dev/null @@ -1,1309 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
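Editorial note: the local fan-out removed above special-cases a device id of "*": the same payload is duplicated for every device the user has registered, and otherwise only devices that actually exist are kept. A condensed sketch of that expansion over plain dicts; the device lookup is a stand-in for the devices-table query.

import json

def expand_messages(messages_by_device, known_devices):
    """Return {device_id: message_json}, expanding a '*' wildcard."""
    if set(messages_by_device) == {"*"}:
        payload = json.dumps(messages_by_device["*"])
        return {device_id: payload for device_id in known_devices}
    # Otherwise only keep devices that actually exist on this server.
    return {
        device_id: json.dumps(content)
        for device_id, content in messages_by_device.items()
        if device_id in known_devices
    }

print(expand_messages({"*": {"type": "m.ping"}}, ["PHONE", "LAPTOP"]))
print(expand_messages({"PHONE": {"type": "m.ping"}, "GHOST": {}}, ["PHONE", "LAPTOP"]))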
-import logging -from typing import List, Optional, Set, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import Codes, StoreError -from synapse.logging.opentracing import ( - get_active_span_text_map, - set_tag, - trace, - whitelisted_homeserver, -) -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import ( - Database, - LoggingTransaction, - make_tuple_comparison_clause, -) -from synapse.types import Collection, get_verify_key_from_cross_signing_key -from synapse.util.caches.descriptors import ( - Cache, - cached, - cachedInlineCallbacks, - cachedList, -) -from synapse.util.iterutils import batch_iter -from synapse.util.stringutils import shortstr - -logger = logging.getLogger(__name__) - -DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( - "drop_device_list_streams_non_unique_indexes" -) - -BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" - - -class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): - """Retrieve a device. Only returns devices that are not marked as - hidden. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve - Returns: - defer.Deferred for a dict containing the device information - Raises: - StoreError: if the device is not found - """ - return self.db.simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - ) - - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): - """Retrieve all of a user's registered devices. Only returns devices - that are not marked as hidden. - - Args: - user_id (str): - Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. - """ - devices = yield self.db.simple_select_list( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_devices_by_user", - ) - - return {d["device_id"]: d for d in devices} - - @trace - @defer.inlineCallbacks - def get_device_updates_by_remote(self, destination, from_stream_id, limit): - """Get a stream of device updates to send to the given remote server. 
- - Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - limit (int): Maximum number of device updates to return - Returns: - Deferred[tuple[int, list[tuple[string,dict]]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates, where each update is a pair of EDU - type and EDU contents - """ - now_stream_id = self._device_list_id_gen.get_current_token() - - has_changed = self._device_list_federation_stream_cache.has_entity_changed( - destination, int(from_stream_id) - ) - if not has_changed: - return now_stream_id, [] - - updates = yield self.db.runInteraction( - "get_device_updates_by_remote", - self._get_device_updates_by_remote_txn, - destination, - from_stream_id, - now_stream_id, - limit, - ) - - # Return an empty list if there are no updates - if not updates: - return now_stream_id, [] - - # get the cross-signing keys of the users in the list, so that we can - # determine which of the device changes were cross-signing keys - users = {r[0] for r in updates} - master_key_by_user = {} - self_signing_key_by_user = {} - for user in users: - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") - if cross_signing_key: - key_id, verify_key = get_verify_key_from_cross_signing_key( - cross_signing_key - ) - # verify_key is a VerifyKey from signedjson, which uses - # .version to denote the portion of the key ID after the - # algorithm and colon, which is the device ID - master_key_by_user[user] = { - "key_info": cross_signing_key, - "device_id": verify_key.version, - } - - cross_signing_key = yield self.get_e2e_cross_signing_key( - user, "self_signing" - ) - if cross_signing_key: - key_id, verify_key = get_verify_key_from_cross_signing_key( - cross_signing_key - ) - self_signing_key_by_user[user] = { - "key_info": cross_signing_key, - "device_id": verify_key.version, - } - - # Perform the equivalent of a GROUP BY - # - # Iterate through the updates list and copy non-duplicate - # (user_id, device_id) entries into a map, with the value being - # the max stream_id across each set of duplicate entries - # - # maps (user_id, device_id) -> (stream_id, opentracing_context) - # - # opentracing_context contains the opentracing metadata for the request - # that created the poke - # - # The most recent request's opentracing_context is used as the - # context which created the Edu. 
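Editorial note: the update rows fetched above can contain several entries for the same (user_id, device_id), and only the newest one (with its opentracing context) should drive the EDU. A short sketch of that manual "GROUP BY", ignoring the cross-signing special cases handled in the removed code.

def collapse_updates(updates):
    """updates: iterable of (user_id, device_id, stream_id, context).

    Keep, for each (user_id, device_id), the row with the highest stream_id.
    """
    query_map = {}
    for user_id, device_id, stream_id, context in updates:
        key = (user_id, device_id)
        prev_stream_id, _ = query_map.get(key, (0, None))
        if stream_id > prev_stream_id:
            query_map[key] = (stream_id, context)
    return query_map

rows = [
    ("@u:ex", "DEV", 5, "ctx-a"),
    ("@u:ex", "DEV", 9, "ctx-b"),   # newer row for the same device wins
    ("@u:ex", "TABLET", 7, "ctx-c"),
]
print(collapse_updates(rows))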
- - query_map = {} - cross_signing_keys_by_user = {} - for user_id, device_id, update_stream_id, update_context in updates: - if ( - user_id in master_key_by_user - and device_id == master_key_by_user[user_id]["device_id"] - ): - result = cross_signing_keys_by_user.setdefault(user_id, {}) - result["master_key"] = master_key_by_user[user_id]["key_info"] - elif ( - user_id in self_signing_key_by_user - and device_id == self_signing_key_by_user[user_id]["device_id"] - ): - result = cross_signing_keys_by_user.setdefault(user_id, {}) - result["self_signing_key"] = self_signing_key_by_user[user_id][ - "key_info" - ] - else: - key = (user_id, device_id) - - previous_update_stream_id, _ = query_map.get(key, (0, None)) - - if update_stream_id > previous_update_stream_id: - query_map[key] = (update_stream_id, update_context) - - results = yield self._get_device_update_edus_by_remote( - destination, from_stream_id, query_map - ) - - # add the updated cross-signing keys to the results list - for user_id, result in cross_signing_keys_by_user.items(): - result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec - results.append(("org.matrix.signing_key_update", result)) - - return now_stream_id, results - - def _get_device_updates_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit - ): - """Return device update information for a given remote destination - - Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return - - Returns: - List: List of device updates - """ - # get the list of device updates that need to be sent - sql = """ - SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? - ORDER BY stream_id - LIMIT ? 
- """ - txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) - - return list(txn) - - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): - """Returns a list of device update EDUs as well as E2EE keys - - Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded - opentracing context - - Returns: - List[Dict]: List of objects representing an device update EDU - - """ - devices = ( - yield self.db.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, - ) - if query_map - else {} - ) - - results = [] - for user_id, user_devices in devices.items(): - # The prev_id for the first row is always the last row before - # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( - destination, user_id, from_stream_id - ) - - # make sure we go through the devices in stream order - device_ids = sorted( - user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], - ) - - for device_id in device_ids: - device = user_devices[device_id] - stream_id, opentracing_context = query_map[(user_id, device_id)] - result = { - "user_id": user_id, - "device_id": device_id, - "prev_id": [prev_id] if prev_id else [], - "stream_id": stream_id, - "org.matrix.opentracing_context": opentracing_context, - } - - prev_id = stream_id - - if device is not None: - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - else: - result["deleted"] = True - - results.append(("m.device_list_update", result)) - - return results - - def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id - ): - def f(txn): - prev_sent_id_sql = """ - SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? AND stream_id <= ? - """ - txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) - rows = txn.fetchall() - return rows[0][0] - - return self.db.runInteraction("get_last_device_update_for_remote_user", f) - - def mark_as_sent_devices_by_remote(self, destination, stream_id): - """Mark that updates have successfully been sent to the destination. - """ - return self.db.runInteraction( - "mark_as_sent_devices_by_remote", - self._mark_as_sent_devices_by_remote_txn, - destination, - stream_id, - ) - - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # We update the device_lists_outbound_last_success with the successfully - # poked users. - sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0) - FROM device_lists_outbound_pokes as o - WHERE destination = ? AND o.stream_id <= ? 
- GROUP BY user_id - """ - txn.execute(sql, (destination, stream_id)) - rows = txn.fetchall() - - self.db.simple_upsert_many_txn( - txn=txn, - table="device_lists_outbound_last_success", - key_names=("destination", "user_id"), - key_values=((destination, user_id) for user_id, _ in rows), - value_names=("stream_id",), - value_values=((stream_id,) for _, stream_id in rows), - ) - - # Delete all sent outbound pokes - sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? - """ - txn.execute(sql, (destination, stream_id)) - - @defer.inlineCallbacks - def add_user_signature_change_to_streams(self, from_user_id, user_ids): - """Persist that a user has made new signatures - - Args: - from_user_id (str): the user who made the signatures - user_ids (list[str]): the users who were signed - """ - - with self._device_list_id_gen.get_next() as stream_id: - yield self.db.runInteraction( - "add_user_sig_change_to_streams", - self._add_user_signature_change_txn, - from_user_id, - user_ids, - stream_id, - ) - return stream_id - - def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): - txn.call_after( - self._user_signature_stream_cache.entity_has_changed, - from_user_id, - stream_id, - ) - self.db.simple_insert_txn( - txn, - "user_signature_stream", - values={ - "stream_id": stream_id, - "from_user_id": from_user_id, - "user_ids": json.dumps(user_ids), - }, - ) - - def get_device_stream_token(self): - return self._device_list_id_gen.get_current_token() - - @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): - """Get the devices (and keys if any) for remote users from the cache. - - Args: - query_list(list): List of (user_id, device_ids), if device_ids is - falsey then return all device ids for that user. - - Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info - """ - user_ids = {user_id for user_id, _ in query_list} - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) - - # We go and check if any of the users need to have their device lists - # resynced. If they do then we remove them from the cached list. 
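Editorial note: get_user_devices_from_cache, removed above, answers from the local cache only for users whose device list is both tracked (non-zero stream id) and not flagged for resync; everyone else is reported back as a cache miss. A small sketch of that partition over plain sets; the function name is illustrative.

def partition_users(user_map, needing_resync):
    """user_map: {user_id: last_stream_id or None}; needing_resync: set of user_ids.

    Returns (cache_misses, cache_hits).
    """
    in_cache = {
        user_id for user_id, stream_id in user_map.items() if stream_id
    } - set(needing_resync)
    misses = set(user_map) - in_cache
    return misses, in_cache

user_map = {"@a:ex": 12, "@b:ex": None, "@c:ex": 7}
# @b has never been cached, @c is marked for resync, so only @a is a cache hit.
print(partition_users(user_map, {"@c:ex"}))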
- users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( - user_ids - ) - user_ids_in_cache = { - user_id for user_id, stream_id in user_map.items() if stream_id - } - users_needing_resync - user_ids_not_in_cache = user_ids - user_ids_in_cache - - results = {} - for user_id, device_id in query_list: - if user_id not in user_ids_in_cache: - continue - - if device_id: - device = yield self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device - else: - results[user_id] = yield self.get_cached_devices_for_user(user_id) - - set_tag("in_cache", results) - set_tag("not_in_cache", user_ids_not_in_cache) - - return user_ids_not_in_cache, results - - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self.db.simple_select_one_onecol( - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="content", - desc="_get_cached_user_device", - ) - return db_to_json(content) - - @cachedInlineCallbacks() - def get_cached_devices_for_user(self, user_id): - devices = yield self.db.simple_select_list( - table="device_lists_remote_cache", - keyvalues={"user_id": user_id}, - retcols=("device_id", "content"), - desc="get_cached_devices_for_user", - ) - return { - device["device_id"]: db_to_json(device["content"]) for device in devices - } - - def get_devices_with_keys_by_user(self, user_id): - """Get all devices (with any device keys) for a user - - Returns: - (stream_id, devices) - """ - return self.db.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn(self, txn, user_id): - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn( - txn, [(user_id, None)], include_all_devices=True - ) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - - def get_users_whose_devices_changed(self, from_key, user_ids): - """Get set of users whose devices have changed since `from_key` that - are in the given list of user_ids. - - Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) - - Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` - """ - from_key = int(from_key) - - # Get set of users who *may* have changed. Users not in the returned - # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) - - if not to_check: - return defer.succeed(set()) - - def _get_users_whose_devices_changed_txn(txn): - changes = set() - - sql = """ - SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? 
- AND - """ - - for chunk in batch_iter(to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + clause, (from_key,) + tuple(args)) - changes.update(user_id for user_id, in txn) - - return changes - - return self.db.runInteraction( - "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn - ) - - @defer.inlineCallbacks - def get_users_whose_signatures_changed(self, user_id, from_key): - """Get the users who have new cross-signing signatures made by `user_id` since - `from_key`. - - Args: - user_id (str): the user who made the signatures - from_key (str): The device lists stream token - """ - from_key = int(from_key) - if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): - sql = """ - SELECT DISTINCT user_ids FROM user_signature_stream - WHERE from_user_id = ? AND stream_id > ? - """ - rows = yield self.db.execute( - "get_users_whose_signatures_changed", None, sql, user_id, from_key - ) - return {user for row in rows for user in db_to_json(row[0])} - else: - return set() - - async def get_all_device_list_changes_for_remotes( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for device lists replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def _get_all_device_list_changes_for_remotes(txn): - # This query Does The Right Thing where it'll correctly apply the - # bounds to the inner queries. - sql = """ - SELECT stream_id, entity FROM ( - SELECT stream_id, user_id AS entity FROM device_lists_stream - UNION ALL - SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes - ) AS e - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? - """ - - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db.runInteraction( - "get_all_device_list_changes_for_remotes", - _get_all_device_list_changes_for_remotes, - ) - - @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): - """Get the last stream_id we got for a user. May be None if we haven't - got any information for them. 
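Editorial note: the changed-devices lookup above avoids oversized IN (...) clauses by slicing the candidate user ids into chunks of 100 and running one query per chunk. A generic sqlite3 sketch of that batching; the chunk size and schema are illustrative.

import sqlite3
from itertools import islice

def batch_iter(iterable, size):
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, size))
        if not chunk:
            return
        yield chunk

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE device_lists_stream (stream_id INTEGER, user_id TEXT)")
conn.executemany(
    "INSERT INTO device_lists_stream VALUES (?, ?)",
    [(1, "@a:ex"), (5, "@b:ex"), (9, "@c:ex")],
)

def users_changed_since(from_key, user_ids, chunk_size=100):
    changed = set()
    for chunk in batch_iter(user_ids, chunk_size):
        placeholders = ", ".join("?" for _ in chunk)
        rows = conn.execute(
            "SELECT DISTINCT user_id FROM device_lists_stream"
            " WHERE stream_id > ? AND user_id IN (%s)" % placeholders,
            (from_key,) + chunk,
        )
        changed.update(user_id for (user_id,) in rows)
    return changed

print(users_changed_since(4, ["@a:ex", "@b:ex", "@c:ex"]))  # {'@b:ex', '@c:ex'}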
- """ - return self.db.simple_select_one_onecol( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - retcol="stream_id", - desc="get_device_list_last_stream_id_for_remote", - allow_none=True, - ) - - @cachedList( - cached_method_name="get_device_list_last_stream_id_for_remote", - list_name="user_ids", - inlineCallbacks=True, - ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self.db.simple_select_many_batch( - table="device_lists_remote_extremeties", - column="user_id", - iterable=user_ids, - retcols=("user_id", "stream_id"), - desc="get_device_list_last_stream_id_for_remotes", - ) - - results = {user_id: None for user_id in user_ids} - results.update({row["user_id"]: row["stream_id"] for row in rows}) - - return results - - @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync( - self, user_ids: Optional[Collection[str]] = None, - ) -> Set[str]: - """Given a list of remote users return the list of users that we - should resync the device lists for. If None is given instead of a list, - return every user that we should resync the device lists for. - - Returns: - The IDs of users whose device lists need resync. - """ - if user_ids: - rows = yield self.db.simple_select_many_batch( - table="device_lists_remote_resync", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="get_user_ids_requiring_device_list_resync_with_iterable", - ) - else: - rows = yield self.db.simple_select_list( - table="device_lists_remote_resync", - keyvalues=None, - retcols=("user_id",), - desc="get_user_ids_requiring_device_list_resync", - ) - - return {row["user_id"] for row in rows} - - def mark_remote_user_device_cache_as_stale(self, user_id: str): - """Records that the server has reason to believe the cache of the devices - for the remote users is out of date. - """ - return self.db.simple_upsert( - table="device_lists_remote_resync", - keyvalues={"user_id": user_id}, - values={}, - insertion_values={"added_ts": self._clock.time_msec()}, - desc="make_remote_user_device_cache_as_stale", - ) - - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. 
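Editorial note: marking a remote user's device cache as stale, as in the removed code above, is just an idempotent upsert into a "needs resync" table that later cache lookups subtract from. A minimal sqlite3 sketch of that flagging, using INSERT OR IGNORE in place of simple_upsert.

import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE device_lists_remote_resync"
    " (user_id TEXT PRIMARY KEY, added_ts INTEGER)"
)

def mark_stale(user_id):
    # Idempotent: marking the same user twice leaves a single row.
    conn.execute(
        "INSERT OR IGNORE INTO device_lists_remote_resync VALUES (?, ?)",
        (user_id, int(time.time() * 1000)),
    )

def users_requiring_resync():
    return {row[0] for row in conn.execute("SELECT user_id FROM device_lists_remote_resync")}

mark_stale("@remote:other.example")
mark_stale("@remote:other.example")
print(users_requiring_resync())  # {'@remote:other.example'}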
- """ - - def _mark_remote_user_device_list_as_unsubscribed_txn(txn): - self.db.simple_delete_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - ) - self._invalidate_cache_and_stream( - txn, self.get_device_list_last_stream_id_for_remote, (user_id,) - ) - - return self.db.runInteraction( - "mark_remote_user_device_list_as_unsubscribed", - _mark_remote_user_device_list_as_unsubscribed_txn, - ) - - -class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.db.updates.register_background_index_update( - "device_lists_stream_idx", - index_name="device_lists_stream_user_id", - table="device_lists_stream", - columns=["user_id", "device_id"], - ) - - # create a unique index on device_lists_remote_cache - self.db.updates.register_background_index_update( - "device_lists_remote_cache_unique_idx", - index_name="device_lists_remote_cache_unique_id", - table="device_lists_remote_cache", - columns=["user_id", "device_id"], - unique=True, - ) - - # And one on device_lists_remote_extremeties - self.db.updates.register_background_index_update( - "device_lists_remote_extremeties_unique_idx", - index_name="device_lists_remote_extremeties_unique_idx", - table="device_lists_remote_extremeties", - columns=["user_id"], - unique=True, - ) - - # once they complete, we can remove the old non-unique indexes. - self.db.updates.register_background_update_handler( - DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, - self._drop_device_list_streams_non_unique_indexes, - ) - - # clear out duplicate device list outbound pokes - self.db.updates.register_background_update_handler( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, - ) - - # a pair of background updates that were added during the 1.14 release cycle, - # but replaced with 58/06dlols_unique_idx.py - self.db.updates.register_noop_background_update( - "device_lists_outbound_last_success_unique_idx", - ) - self.db.updates.register_noop_background_update( - "drop_device_lists_outbound_last_success_non_unique_idx", - ) - - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") - txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") - txn.close() - - yield self.db.runWithConnection(f) - yield self.db.updates._end_background_update( - DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES - ) - return 1 - - async def _remove_duplicate_outbound_pokes(self, progress, batch_size): - # for some reason, we have accumulated duplicate entries in - # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less - # efficient. - # - # For each duplicate, we delete all the existing rows and put one back. - - KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] - last_row = progress.get( - "last_row", - {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, - ) - - def _txn(txn): - clause, args = make_tuple_comparison_clause( - self.db.engine, [(x, last_row[x]) for x in KEY_COLS] - ) - sql = """ - SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts - FROM device_lists_outbound_pokes - WHERE %s - GROUP BY %s - HAVING count(*) > 1 - ORDER BY %s - LIMIT ? 
- """ % ( - clause, # WHERE - ",".join(KEY_COLS), # GROUP BY - ",".join(KEY_COLS), # ORDER BY - ) - txn.execute(sql, args + [batch_size]) - rows = self.db.cursor_to_dict(txn) - - row = None - for row in rows: - self.db.simple_delete_txn( - txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, - ) - - row["sent"] = False - self.db.simple_insert_txn( - txn, "device_lists_outbound_pokes", row, - ) - - if row: - self.db.updates._background_update_progress_txn( - txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, - ) - - return len(rows) - - rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) - - if not rows: - await self.db.updates._end_background_update( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES - ) - - return rows - - -class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): - super(DeviceStore, self).__init__(database, db_conn, hs) - - # Map of (user_id, device_id) -> bool. If there is an entry that implies - # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", keylen=2, max_entries=10000 - ) - - self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): - """Ensure the given device is known; add it to the store if not - - Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. - Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. - Raises: - StoreError: if the device is already in use - """ - key = (user_id, device_id) - if self.device_id_exists_cache.get(key, None): - return False - - try: - inserted = yield self.db.simple_insert( - "devices", - values={ - "user_id": user_id, - "device_id": device_id, - "display_name": initial_device_display_name, - "hidden": False, - }, - desc="store_device", - or_ignore=True, - ) - if not inserted: - # if the device already exists, check if it's a real device, or - # if the device ID is reserved by something else - hidden = yield self.db.simple_select_one_onecol( - "devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="hidden", - ) - if hidden: - raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - self.device_id_exists_cache.prefill(key, True) - return inserted - except StoreError: - raise - except Exception as e: - logger.error( - "store_device with device_id=%s(%r) user_id=%s(%r)" - " display_name=%s(%r) failed: %s", - type(device_id).__name__, - device_id, - type(user_id).__name__, - user_id, - type(initial_device_display_name).__name__, - initial_device_display_name, - e, - ) - raise StoreError(500, "Problem storing device.") - - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): - """Delete a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred - """ - yield self.db.simple_delete_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - desc="delete_device", - ) - - self.device_id_exists_cache.invalidate((user_id, device_id)) - - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): - """Deletes several devices. 
- - Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred - """ - yield self.db.simple_delete_many( - table="devices", - column="device_id", - iterable=device_ids, - keyvalues={"user_id": user_id, "hidden": False}, - desc="delete_devices", - ) - for device_id in device_ids: - self.device_id_exists_cache.invalidate((user_id, device_id)) - - def update_device(self, user_id, device_id, new_display_name=None): - """Update a device. Only updates the device if it is not marked as - hidden. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged - Raises: - StoreError: if the device is not found - Returns: - defer.Deferred - """ - updates = {} - if new_display_name is not None: - updates["display_name"] = new_display_name - if not updates: - return defer.succeed(None) - return self.db.simple_update_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - updatevalues=updates, - desc="update_device", - ) - - def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id - ): - """Updates a single device in the cache of a remote user's devicelist. - - Note: assumes that we are the only thread that can be updating this user's - device list. - - Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list - - Returns: - Deferred[None] - """ - return self.db.runInteraction( - "update_remote_device_list_cache_entry", - self._update_remote_device_list_cache_entry_txn, - user_id, - device_id, - content, - stream_id, - ) - - def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): - if content.get("deleted"): - self.db.simple_delete_txn( - txn, - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - - txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) - else: - self.db.simple_upsert_txn( - txn, - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, - # we don't need to lock, because we assume we are the only thread - # updating this user's devices. - lock=False, - ) - - txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) - txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) - ) - - self.db.simple_upsert_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - values={"stream_id": stream_id}, - # again, we can assume we are the only thread updating this user's - # extremity. - lock=False, - ) - - def update_remote_device_list_cache(self, user_id, devices, stream_id): - """Replace the entire cache of the remote user's devices. - - Note: assumes that we are the only thread that can be updating this user's - device list. 
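# Aside (illustrative sketch): update_remote_device_list_cache_entry (above)
# either drops the cached entry (when the update says the device was deleted)
# or upserts the new content, and in both cases records the stream_id as the
# new extremity for that user.  The same state transition over plain dicts;
# the function name and cache layout are invented for illustration.
def apply_device_update(cache, extremities, user_id, device_id, content, stream_id):
    """cache: {(user_id, device_id): content}, extremities: {user_id: stream_id}"""
    if content.get("deleted"):
        cache.pop((user_id, device_id), None)
    else:
        cache[(user_id, device_id)] = content
    # always advance the per-user extremity to the version just processed
    extremities[user_id] = stream_id


cache, extremities = {}, {}
apply_device_update(cache, extremities, "@bob:remote", "DEV", {"keys": {}}, 5)
assert ("@bob:remote", "DEV") in cache and extremities["@bob:remote"] == 5
apply_device_update(cache, extremities, "@bob:remote", "DEV", {"deleted": True}, 6)
assert ("@bob:remote", "DEV") not in cache and extremities["@bob:remote"] == 6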
- - Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list - - Returns: - Deferred[None] - """ - return self.db.runInteraction( - "update_remote_device_list_cache", - self._update_remote_device_list_cache_txn, - user_id, - devices, - stream_id, - ) - - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self.db.simple_delete_txn( - txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} - ) - - self.db.simple_insert_many_txn( - txn, - table="device_lists_remote_cache", - values=[ - { - "user_id": user_id, - "device_id": content["device_id"], - "content": json.dumps(content), - } - for content in devices - ], - ) - - txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) - ) - - self.db.simple_upsert_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - values={"stream_id": stream_id}, - # we don't need to lock, because we can assume we are the only thread - # updating this user's extremity. - lock=False, - ) - - # If we're replacing the remote user's device list cache presumably - # we've done a full resync, so we remove the entry that says we need - # to resync - self.db.simple_delete_txn( - txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, - ) - - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): - """Persist that a user's devices have been updated, and which hosts - (if any) should be poked. - """ - if not device_ids: - return - - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: - yield self.db.runInteraction( - "add_device_change_to_stream", - self._add_device_change_to_stream_txn, - user_id, - device_ids, - stream_ids, - ) - - if not hosts: - return stream_ids[-1] - - context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( - len(hosts) * len(device_ids) - ) as stream_ids: - yield self.db.runInteraction( - "add_device_outbound_poke_to_stream", - self._add_device_outbound_poke_to_stream_txn, - user_id, - device_ids, - hosts, - stream_ids, - context, - ) - - return stream_ids[-1] - - def _add_device_change_to_stream_txn( - self, - txn: LoggingTransaction, - user_id: str, - device_ids: Collection[str], - stream_ids: List[str], - ): - txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], - ) - - min_stream_id = stream_ids[0] - - # Delete older entries in the table, as we really only care about - # when the latest change happened. - txn.executemany( - """ - DELETE FROM device_lists_stream - WHERE user_id = ? AND device_id = ? AND stream_id < ? 
- """, - [(user_id, device_id, min_stream_id) for device_id in device_ids], - ) - - self.db.simple_insert_many_txn( - txn, - table="device_lists_stream", - values=[ - {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for stream_id, device_id in zip(stream_ids, device_ids) - ], - ) - - def _add_device_outbound_poke_to_stream_txn( - self, txn, user_id, device_ids, hosts, stream_ids, context, - ): - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_ids[-1], - ) - - now = self._clock.time_msec() - next_stream_id = iter(stream_ids) - - self.db.simple_insert_many_txn( - txn, - table="device_lists_outbound_pokes", - values=[ - { - "destination": destination, - "stream_id": next(next_stream_id), - "user_id": user_id, - "device_id": device_id, - "sent": False, - "ts": now, - "opentracing_context": json.dumps(context) - if whitelisted_homeserver(destination) - else "{}", - } - for destination in hosts - for device_id in device_ids - ], - ) - - def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): - """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. - - Normally, we try to send device updates as a delta since a previous known point: - this is done by setting the prev_id in the m.device_list_update EDU. However, - for that to work, we have to have a complete record of each change to - each device, which can add up to quite a lot of data. - - An alternative mechanism is that, if the remote server sees that it has missed - an entry in the stream_id sequence for a given user, it will request a full - list of that user's devices. Hence, we can reduce the amount of data we have to - store (and transmit in some future transaction), by clearing almost everything - for a given destination out of the database, and having the remote server - resync. - - All we need to do is make sure we keep at least one row for each - (user, destination) pair, to remind us to send a m.device_list_update EDU for - that user when the destination comes back. It doesn't matter which device - we keep. - """ - yesterday = self._clock.time_msec() - prune_age - - def _prune_txn(txn): - # look for (user, destination) pairs which have an update older than - # the cutoff. - # - # For each pair, we also need to know the most recent stream_id, and - # an arbitrary device_id at that stream_id. - select_sql = """ - SELECT - dlop1.destination, - dlop1.user_id, - MAX(dlop1.stream_id) AS stream_id, - (SELECT MIN(dlop2.device_id) AS device_id FROM - device_lists_outbound_pokes dlop2 - WHERE dlop2.destination = dlop1.destination AND - dlop2.user_id=dlop1.user_id AND - dlop2.stream_id=MAX(dlop1.stream_id) - ) - FROM device_lists_outbound_pokes dlop1 - GROUP BY destination, user_id - HAVING min(ts) < ? AND count(*) > 1 - """ - - txn.execute(select_sql, (yesterday,)) - rows = txn.fetchall() - - if not rows: - return - - logger.info( - "Pruning old outbound device list updates for %i users/destinations: %s", - len(rows), - shortstr((row[0], row[1]) for row in rows), - ) - - # we want to keep the update with the highest stream_id for each user. - # - # there might be more than one update (with different device_ids) with the - # same stream_id, so we also delete all but one rows with the max stream id. - delete_sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND ( - stream_id < ? OR - (stream_id = ? AND device_id != ?) 
- ) - """ - count = 0 - for (destination, user_id, stream_id, device_id) in rows: - txn.execute( - delete_sql, (destination, user_id, stream_id, stream_id, device_id) - ) - count += txn.rowcount - - # Since we've deleted unsent deltas, we need to remove the entry - # of last successful sent so that the prev_ids are correctly set. - sql = """ - DELETE FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) - - logger.info("Pruned %d device list outbound pokes", count) - - return run_as_background_process( - "prune_old_outbound_device_pokes", - self.db.runInteraction, - "_prune_old_outbound_device_pokes", - _prune_txn, - ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py deleted file mode 100644 index e1d1bc3e05..0000000000 --- a/synapse/storage/data_stores/main/directory.py +++ /dev/null @@ -1,195 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple -from typing import Optional - -from twisted.internet import defer - -from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore -from synapse.util.caches.descriptors import cached - -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) - - -class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias - - Args: - room_alias (RoomAlias) - - Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found - """ - room_id = yield self.db.simple_select_one_onecol( - "room_aliases", - {"room_alias": room_alias.to_string()}, - "room_id", - allow_none=True, - desc="get_association_from_room_alias", - ) - - if not room_id: - return None - - servers = yield self.db.simple_select_onecol( - "room_alias_servers", - {"room_alias": room_alias.to_string()}, - "server", - desc="get_association_from_room_alias", - ) - - if not servers: - return None - - return RoomAliasMapping(room_id, room_alias.to_string(), servers) - - def get_room_alias_creator(self, room_alias): - return self.db.simple_select_one_onecol( - table="room_aliases", - keyvalues={"room_alias": room_alias}, - retcol="creator", - desc="get_room_alias_creator", - ) - - @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self.db.simple_select_onecol( - "room_aliases", - {"room_id": room_id}, - "room_alias", - desc="get_aliases_for_room", - ) - - -class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an association between a room alias and room_id/servers - - Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. 
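# Aside (illustrative sketch): get_association_from_room_alias (above) resolves
# an alias to a RoomAliasMapping(room_id, room_alias, servers), returning None
# when either the alias row or its server list is missing.  A dict-backed
# stand-in for the two tables; resolve_alias and its arguments are invented,
# not Synapse's API.
from collections import namedtuple
from typing import Dict, List, Optional

RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))


def resolve_alias(
    alias: str,
    room_aliases: Dict[str, str],          # room_alias -> room_id
    alias_servers: Dict[str, List[str]],   # room_alias -> servers
) -> Optional[RoomAliasMapping]:
    room_id = room_aliases.get(alias)
    if not room_id:
        return None
    servers = alias_servers.get(alias)
    if not servers:
        return None
    return RoomAliasMapping(room_id, alias, servers)


mapping = resolve_alias(
    "#room:example.com",
    {"#room:example.com": "!abc:example.com"},
    {"#room:example.com": ["example.com", "other.example"]},
)
assert mapping is not None and mapping.room_id == "!abc:example.com"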
- - Returns: - Deferred - """ - - def alias_txn(txn): - self.db.simple_insert_txn( - txn, - "room_aliases", - { - "room_alias": room_alias.to_string(), - "room_id": room_id, - "creator": creator, - }, - ) - - self.db.simple_insert_many_txn( - txn, - table="room_alias_servers", - values=[ - {"room_alias": room_alias.to_string(), "server": server} - for server in servers - ], - ) - - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (room_id,) - ) - - try: - ret = yield self.db.runInteraction( - "create_room_alias_association", alias_txn - ) - except self.database_engine.module.IntegrityError: - raise SynapseError( - 409, "Room alias %s already exists" % room_alias.to_string() - ) - return ret - - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.db.runInteraction( - "delete_room_alias", self._delete_room_alias_txn, room_alias - ) - - return room_id - - def _delete_room_alias_txn(self, txn, room_alias): - txn.execute( - "SELECT room_id FROM room_aliases WHERE room_alias = ?", - (room_alias.to_string(),), - ) - - res = txn.fetchone() - if res: - room_id = res[0] - else: - return None - - txn.execute( - "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),) - ) - - txn.execute( - "DELETE FROM room_alias_servers WHERE room_alias = ?", - (room_alias.to_string(),), - ) - - self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,)) - - return room_id - - def update_aliases_for_room( - self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): - """Repoint all of the aliases for a given room, to a different room. - - Args: - old_room_id: - new_room_id: - creator: The user to record as the creator of the new mapping. - If None, the creator will be left unchanged. - """ - - def _update_aliases_for_room_txn(txn): - update_creator_sql = "" - sql_params = (new_room_id, old_room_id) - if creator: - update_creator_sql = ", creator = ?" - sql_params = (new_room_id, creator, old_room_id) - - sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % ( - update_creator_sql, - ) - txn.execute(sql, sql_params) - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (old_room_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (new_room_id,) - ) - - return self.db.runInteraction( - "_update_aliases_for_room_txn", _update_aliases_for_room_txn - ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py deleted file mode 100644 index 615364f018..0000000000 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ /dev/null @@ -1,439 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.opentracing import log_kv, trace -from synapse.storage._base import SQLBaseStore, db_to_json - - -class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces the encrypted E2E room key for a given session in a given backup - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - Raises: - StoreError - """ - - yield self.db.simple_update_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - updatevalues={ - "first_message_index": room_key["first_message_index"], - "forwarded_count": room_key["forwarded_count"], - "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), - }, - desc="update_e2e_room_key", - ) - - @defer.inlineCallbacks - def add_e2e_room_keys(self, user_id, version, room_keys): - """Bulk add room keys to a given backup. - - Args: - user_id (str): the user whose backup we're adding to - version (str): the version ID of the backup for the set of keys we're adding to - room_keys (iterable[(str, str, dict)]): the keys to add, in the form - (roomID, sessionID, keyData) - """ - - values = [] - for (room_id, session_id, room_key) in room_keys: - values.append( - { - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - "first_message_index": room_key["first_message_index"], - "forwarded_count": room_key["forwarded_count"], - "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), - } - ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } - ) - - yield self.db.simple_insert_many( - table="e2e_room_keys", values=values, desc="add_e2e_room_keys" - ) - - @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): - """Bulk get the E2E room keys for a given backup, optionally filtered to a given - room, or a given session. - - Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup for the set of keys we're querying - room_id (str): Optional. the ID of the room whose keys we're querying, if any. - If not specified, we return the keys for all the rooms in the backup. - session_id (str): Optional. the session whose room_key we're querying, if any. - If specified, we also require the room_id to be specified. - If not specified, we return all the keys in this version of - the backup (or for the specified room) - - Returns: - A deferred list of dicts giving the session_data and message metadata for - these room keys. 
- """ - - try: - version = int(version) - except ValueError: - return {"rooms": {}} - - keyvalues = {"user_id": user_id, "version": version} - if room_id: - keyvalues["room_id"] = room_id - if session_id: - keyvalues["session_id"] = session_id - - rows = yield self.db.simple_select_list( - table="e2e_room_keys", - keyvalues=keyvalues, - retcols=( - "user_id", - "room_id", - "session_id", - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_keys", - ) - - sessions = {"rooms": {}} - for row in rows: - room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) - room_entry["sessions"][row["session_id"]] = { - "first_message_index": row["first_message_index"], - "forwarded_count": row["forwarded_count"], - # is_verified must be returned to the client as a boolean - "is_verified": bool(row["is_verified"]), - "session_data": db_to_json(row["session_data"]), - } - - return sessions - - def get_e2e_room_keys_multi(self, user_id, version, room_keys): - """Get multiple room keys at a time. The difference between this function and - get_e2e_room_keys is that this function can be used to retrieve - multiple specific keys at a time, whereas get_e2e_room_keys is used for - getting all the keys in a backup version, all the keys for a room, or a - specific key. - - Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about - room_keys (dict[str, dict[str, iterable[str]]]): a map from - room ID -> {"session": [session ids]} indicating the session IDs - that we want to query - - Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key - """ - - return self.db.runInteraction( - "get_e2e_room_keys_multi", - self._get_e2e_room_keys_multi_txn, - user_id, - version, - room_keys, - ) - - @staticmethod - def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): - if not room_keys: - return {} - - where_clauses = [] - params = [user_id, version] - for room_id, room in room_keys.items(): - sessions = list(room["sessions"]) - if not sessions: - continue - params.append(room_id) - params.extend(sessions) - where_clauses.append( - "(room_id = ? AND session_id IN (%s))" - % (",".join(["?" for _ in sessions]),) - ) - - # check if we're actually querying something - if not where_clauses: - return {} - - sql = """ - SELECT room_id, session_id, first_message_index, forwarded_count, - is_verified, session_data - FROM e2e_room_keys - WHERE user_id = ? AND version = ? AND (%s) - """ % ( - " OR ".join(where_clauses) - ) - - txn.execute(sql, params) - - ret = {} - - for row in txn: - room_id = row[0] - session_id = row[1] - ret.setdefault(room_id, {}) - ret[room_id][session_id] = { - "first_message_index": row[2], - "forwarded_count": row[3], - "is_verified": row[4], - "session_data": db_to_json(row[5]), - } - - return ret - - def count_e2e_room_keys(self, user_id, version): - """Get the number of keys in a backup version. 
- - Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about - """ - - return self.db.simple_select_one_onecol( - table="e2e_room_keys", - keyvalues={"user_id": user_id, "version": version}, - retcol="COUNT(*)", - desc="count_e2e_room_keys", - ) - - @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): - """Bulk delete the E2E room keys for a given backup, optionally filtered to a given - room or a given session. - - Args: - user_id(str): the user whose backup we're deleting from - version(str): the version ID of the backup for the set of keys we're deleting - room_id(str): Optional. the ID of the room whose keys we're deleting, if any. - If not specified, we delete the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. - If specified, we also require the room_id to be specified. - If not specified, we delete all the keys in this version of - the backup (or for the specified room) - - Returns: - A deferred of the deletion transaction - """ - - keyvalues = {"user_id": user_id, "version": int(version)} - if room_id: - keyvalues["room_id"] = room_id - if session_id: - keyvalues["session_id"] = session_id - - yield self.db.simple_delete( - table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" - ) - - @staticmethod - def _get_current_version(txn, user_id): - txn.execute( - "SELECT MAX(version) FROM e2e_room_keys_versions " - "WHERE user_id=? AND deleted=0", - (user_id,), - ) - row = txn.fetchone() - if not row: - raise StoreError(404, "No current backup version") - return row[0] - - def get_e2e_room_keys_version_info(self, user_id, version=None): - """Get info metadata about a version of our room_keys backup. - - Args: - user_id(str): the user whose backup we're querying - version(str): Optional. the version ID of the backup we're querying about - If missing, we return the information about the current version. - Raises: - StoreError: with code 404 if there are no e2e_room_keys_versions present - Returns: - A deferred dict giving the info metadata for this backup version, with - fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - etag(int): tag of the keys in the backup - """ - - def _get_e2e_room_keys_version_info_txn(txn): - if version is None: - this_version = self._get_current_version(txn, user_id) - else: - try: - this_version = int(version) - except ValueError: - # Our versions are all ints so if we can't convert it to an integer, - # it isn't there. - raise StoreError(404, "No row found") - - result = self.db.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data", "etag"), - ) - result["auth_data"] = db_to_json(result["auth_data"]) - result["version"] = str(result["version"]) - if result["etag"] is None: - result["etag"] = 0 - return result - - return self.db.runInteraction( - "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn - ) - - @trace - def create_e2e_room_keys_version(self, user_id, info): - """Atomically creates a new version of this user's e2e_room_keys store - with the given version info. 
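# Aside (illustrative sketch): backup versions (above) are stored as integers
# but exposed as strings; the "current" version is the highest non-deleted one,
# and a non-numeric version argument is treated as "not found".  A tiny model
# of that convention; BackupVersionError and the dict layout are invented here.
from typing import Dict, Optional


class BackupVersionError(Exception):
    pass


def current_version(versions: Dict[int, dict]) -> int:
    live = [v for v, info in versions.items() if not info.get("deleted")]
    if not live:
        raise BackupVersionError("No current backup version")
    return max(live)


def resolve_version(versions: Dict[int, dict], version: Optional[str]) -> int:
    if version is None:
        return current_version(versions)
    try:
        return int(version)  # all stored versions are ints
    except ValueError:
        raise BackupVersionError("No row found")


versions = {1: {"deleted": 1}, 2: {"deleted": 0}}
assert resolve_version(versions, None) == 2
assert resolve_version(versions, "1") == 1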
- - Args: - user_id(str): the user whose backup we're creating a version - info(dict): the info about the backup version to be created - - Returns: - A deferred string for the newly created version ID - """ - - def _create_e2e_room_keys_version_txn(txn): - txn.execute( - "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", - (user_id,), - ) - current_version = txn.fetchone()[0] - if current_version is None: - current_version = "0" - - new_version = str(int(current_version) + 1) - - self.db.simple_insert_txn( - txn, - table="e2e_room_keys_versions", - values={ - "user_id": user_id, - "version": new_version, - "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), - }, - ) - - return new_version - - return self.db.runInteraction( - "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn - ) - - @trace - def update_e2e_room_keys_version( - self, user_id, version, info=None, version_etag=None - ): - """Update a given backup version - - Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info (dict): the new backup version info to store. If None, then - the backup version info is not updated - version_etag (Optional[int]): etag of the keys in the backup. If - None, then the etag is not updated - """ - updatevalues = {} - - if info is not None and "auth_data" in info: - updatevalues["auth_data"] = json.dumps(info["auth_data"]) - if version_etag is not None: - updatevalues["etag"] = version_etag - - if updatevalues: - return self.db.simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues=updatevalues, - desc="update_e2e_room_keys_version", - ) - - @trace - def delete_e2e_room_keys_version(self, user_id, version=None): - """Delete a given backup version of the user's room keys. - Doesn't delete their actual key data. - - Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting - If missing, we delete the current backup version info. - Raises: - StoreError: with code 404 if there are no e2e_room_keys_versions present, - or if the version requested doesn't exist. - """ - - def _delete_e2e_room_keys_version_txn(txn): - if version is None: - this_version = self._get_current_version(txn, user_id) - if this_version is None: - raise StoreError(404, "No current backup version") - else: - this_version = version - - self.db.simple_delete_txn( - txn, - table="e2e_room_keys", - keyvalues={"user_id": user_id, "version": this_version}, - ) - - return self.db.simple_update_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version}, - updatevalues={"deleted": 1}, - ) - - return self.db.runInteraction( - "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn - ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py deleted file mode 100644 index 317c07a829..0000000000 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ /dev/null @@ -1,746 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Dict, List, Tuple - -from canonicaljson import encode_canonical_json, json - -from twisted.enterprise.adbapi import Connection -from twisted.internet import defer - -from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.iterutils import batch_iter - - -class EndToEndKeyWorkerStore(SQLBaseStore): - @trace - @defer.inlineCallbacks - def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. - Args: - query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. - Returns: - Dict mapping from user-id to dict mapping from device_id to - key data. The key data will be a dict in the same format as the - DeviceKeys type returned by POST /_matrix/client/r0/keys/query. - """ - set_tag("query_list", query_list) - if not query_list: - return {} - - results = yield self.db.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, - ) - - # Build the result structure, un-jsonify the results, and add the - # "unsigned" section - rv = {} - for user_id, device_keys in results.items(): - rv[user_id] = {} - for device_id, device_info in device_keys.items(): - r = db_to_json(device_info.pop("key_json")) - r["unsigned"] = {} - display_name = device_info["device_display_name"] - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name - if "signatures" in device_info: - for sig_user_id, sigs in device_info["signatures"].items(): - r.setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - rv[user_id][device_id] = r - - return rv - - @trace - def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ): - set_tag("include_all_devices", include_all_devices) - set_tag("include_deleted_devices", include_deleted_devices) - - query_clauses = [] - query_params = [] - signature_query_clauses = [] - signature_query_params = [] - - if include_all_devices is False: - include_deleted_devices = False - - if include_deleted_devices: - deleted_devices = set(query_list) - - for (user_id, device_id) in query_list: - query_clause = "user_id = ?" - query_params.append(user_id) - signature_query_clause = "target_user_id = ?" - signature_query_params.append(user_id) - - if device_id is not None: - query_clause += " AND device_id = ?" - query_params.append(device_id) - signature_query_clause += " AND target_device_id = ?" - signature_query_params.append(device_id) - - signature_query_clause += " AND user_id = ?" 
- signature_query_params.append(user_id) - - query_clauses.append(query_clause) - signature_query_clauses.append(signature_query_clause) - - sql = ( - "SELECT user_id, device_id, " - " d.display_name AS device_display_name, " - " k.key_json" - " FROM devices d" - " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s AND NOT d.hidden" - ) % ( - "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses), - ) - - txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) - - result = {} - for row in rows: - if include_deleted_devices: - deleted_devices.remove((row["user_id"], row["device_id"])) - result.setdefault(row["user_id"], {})[row["device_id"]] = row - - if include_deleted_devices: - for user_id, device_id in deleted_devices: - result.setdefault(user_id, {})[device_id] = None - - # get signatures on the device - signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) - - txn.execute(signature_sql, signature_query_params) - rows = self.db.cursor_to_dict(txn) - - # add each cross-signing signature to the correct device in the result dict. - for row in rows: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_user_result = result.get(target_user_id) - if not target_user_result: - continue - - target_device_result = target_user_result.get(target_device_id) - if not target_device_result: - # note that target_device_result will be None for deleted devices. - continue - - target_device_signatures = target_device_result.setdefault("signatures", {}) - signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} - ) - signing_user_signatures[signing_key_id] = signature - - log_kv(result) - return result - - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): - """Retrieve a number of one-time keys for a user - - Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve - - Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key - """ - - rows = yield self.db.simple_select_many_batch( - table="e2e_one_time_keys_json", - column="key_id", - iterable=key_ids, - retcols=("algorithm", "key_id", "key_json"), - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="add_e2e_one_time_keys_check", - ) - result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} - log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) - return result - - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): - """Insert some new one time keys for a device. Errors if any of the - keys already exist. - - Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) - """ - - def _add_e2e_one_time_keys(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", new_keys) - # We are protected from race between lookup and insertion due to - # a unique constraint. 
If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. - self.db.simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - values=[ - { - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - "key_id": key_id, - "ts_added_ms": time_now, - "key_json": json_bytes, - } - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - yield self.db.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys - ) - - @cached(max_entries=10000) - def count_e2e_one_time_keys(self, user_id, device_id): - """ Count the number of one time keys the server has for a device - Returns: - Dict mapping from algorithm to number of keys for that algorithm. - """ - - def _count_e2e_one_time_keys(txn): - sql = ( - "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ?" - " GROUP BY algorithm" - ) - txn.execute(sql, (user_id, device_id)) - result = {} - for algorithm, key_count in txn: - result[algorithm] = key_count - return result - - return self.db.runInteraction( - "count_e2e_one_time_keys", _count_e2e_one_time_keys - ) - - @defer.inlineCallbacks - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the self-signing key will be included in the result - - Returns: - dict of the key data or None if not found - """ - res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) - user_keys = res.get(user_id) - if not user_keys: - return None - return user_keys.get(key_type) - - @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id): - """Dummy function. Only used to make a cache for - _get_bare_e2e_cross_signing_keys_bulk. - """ - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_bare_e2e_cross_signing_keys", - list_name="user_ids", - num_args=1, - ) - def _get_bare_e2e_cross_signing_keys_bulk( - self, user_ids: List[str] - ) -> Dict[str, Dict[str, dict]]: - """Returns the cross-signing keys for a set of users. The output of this - function should be passed to _get_e2e_cross_signing_signatures_txn if - the signatures for the calling user need to be fetched. - - Args: - user_ids (list[str]): the users whose keys are being requested - - Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. - - """ - return self.db.runInteraction( - "get_bare_e2e_cross_signing_keys_bulk", - self._get_bare_e2e_cross_signing_keys_bulk_txn, - user_ids, - ) - - def _get_bare_e2e_cross_signing_keys_bulk_txn( - self, txn: Connection, user_ids: List[str], - ) -> Dict[str, Dict[str, dict]]: - """Returns the cross-signing keys for a set of users. The output of this - function should be passed to _get_e2e_cross_signing_signatures_txn if - the signatures for the calling user need to be fetched. 
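# Aside (illustrative sketch): the device-key lookups earlier in this class
# attach each cross-signing signature row to the matching device entry with
# nested setdefault calls, skipping rows whose target device is missing or
# deleted.  The same merge over plain dicts; the sample data is invented.
def merge_signatures(result, signature_rows):
    for row in signature_rows:
        target = result.get(row["target_user_id"], {}).get(row["target_device_id"])
        if not target:
            # unknown or deleted target device: nothing to attach the signature to
            continue
        target.setdefault("signatures", {}).setdefault(row["user_id"], {})[
            row["key_id"]
        ] = row["signature"]
    return result


result = {"@alice:x": {"DEV1": {"keys": {"ed25519:DEV1": "base64+key"}}}}
rows = [
    {"user_id": "@alice:x", "key_id": "ed25519:SSK", "target_user_id": "@alice:x",
     "target_device_id": "DEV1", "signature": "base64+sig"},
    {"user_id": "@alice:x", "key_id": "ed25519:SSK", "target_user_id": "@alice:x",
     "target_device_id": "GONE", "signature": "base64+sig"},
]
merged = merge_signatures(result, rows)
assert merged["@alice:x"]["DEV1"]["signatures"]["@alice:x"]["ed25519:SSK"] == "base64+sig"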
- - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_ids (list[str]): the users whose keys are being requested - - Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, their user - ID will not be in the dict. - - """ - result = {} - - for user_chunk in batch_iter(user_ids, 100): - clause, params = make_in_list_sql_clause( - txn.database_engine, "k.user_id", user_chunk - ) - sql = ( - """ - SELECT k.user_id, k.keytype, k.keydata, k.stream_id - FROM e2e_cross_signing_keys k - INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id - FROM e2e_cross_signing_keys - GROUP BY user_id, keytype) s - USING (user_id, stream_id, keytype) - WHERE - """ - + clause - ) - - txn.execute(sql, params) - rows = self.db.cursor_to_dict(txn) - - for row in rows: - user_id = row["user_id"] - key_type = row["keytype"] - key = db_to_json(row["keydata"]) - user_info = result.setdefault(user_id, {}) - user_info[key_type] = key - - return result - - def _get_e2e_cross_signing_signatures_txn( - self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, - ) -> Dict[str, Dict[str, dict]]: - """Returns the cross-signing signatures made by a user on a set of keys. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - keys (dict[str, dict[str, dict]]): a map of user ID to key type to - key data. This dict will be modified to add signatures. - from_user_id (str): fetch the signatures made by this user - - Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. The return value will be the same as the keys argument, - with the modifications included. - """ - - # find out what cross-signing keys (a.k.a. devices) we need to get - # signatures for. This is a map of (user_id, device_id) to key type - # (device_id is the key's public part). - devices = {} - - for user_id, user_info in keys.items(): - if user_info is None: - continue - for key_type, key in user_info.items(): - device_id = None - for k in key["keys"].values(): - device_id = k - devices[(user_id, device_id)] = key_type - - for batch in batch_iter(devices.keys(), size=100): - sql = """ - SELECT target_user_id, target_device_id, key_id, signature - FROM e2e_cross_signing_signatures - WHERE user_id = ? - AND (%s) - """ % ( - " OR ".join( - "(target_user_id = ? AND target_device_id = ?)" for _ in batch - ) - ) - query_params = [from_user_id] - for item in batch: - # item is a (user_id, device_id) tuple - query_params.extend(item) - - txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) - - # and add the signatures to the appropriate keys - for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - key_type = devices[(target_user_id, target_device_id)] - # We need to copy everything, because the result may have come - # from the cache. dict.copy only does a shallow copy, so we - # need to recursively copy the dicts that will be modified. 
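# Aside (illustrative sketch): the comment above flags the key subtlety: results
# pulled from a cache must not be mutated in place, so each dict on the path
# being changed is shallow-copied first (copy-on-write).  The same pattern in
# isolation, with invented data; add_signature is not a Synapse function.
def add_signature(keys, target_user, key_type, from_user, key_id, signature):
    # copy every container about to be modified; untouched siblings stay shared
    user_info = keys[target_user] = keys[target_user].copy()
    target_key = user_info[key_type] = user_info[key_type].copy()
    sigs = target_key["signatures"] = target_key.get("signatures", {}).copy()
    user_sigs = sigs[from_user] = sigs.get(from_user, {}).copy()
    user_sigs[key_id] = signature
    return keys


cached = {"@alice:x": {"master": {"keys": {"ed25519:abc": "abc"}}}}
working = dict(cached)  # top-level copy, inner dicts still shared with the cache
add_signature(working, "@alice:x", "master", "@alice:x", "ed25519:SSK", "sig")
# the cached copy is untouched even though the inner dicts were shared
assert "signatures" not in cached["@alice:x"]["master"]
assert working["@alice:x"]["master"]["signatures"]["@alice:x"]["ed25519:SSK"] == "sig"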
- user_info = keys[target_user_id] = keys[target_user_id].copy() - target_user_key = user_info[key_type] = user_info[key_type].copy() - if "signatures" in target_user_key: - signatures = target_user_key["signatures"] = target_user_key[ - "signatures" - ].copy() - if from_user_id in signatures: - user_sigs = signatures[from_user_id] = signatures[from_user_id] - user_sigs[key_id] = row["signature"] - else: - signatures[from_user_id] = {key_id: row["signature"]} - else: - target_user_key["signatures"] = { - from_user_id: {key_id: row["signature"]} - } - - return keys - - @defer.inlineCallbacks - def get_e2e_cross_signing_keys_bulk( - self, user_ids: List[str], from_user_id: str = None - ) -> defer.Deferred: - """Returns the cross-signing keys for a set of users. - - Args: - user_ids (list[str]): the users whose keys are being requested - from_user_id (str): if specified, signatures made by this user on - the self-signing keys will be included in the result - - Returns: - Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to - key data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. - """ - - result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) - - if from_user_id: - result = yield self.db.runInteraction( - "get_e2e_cross_signing_signatures", - self._get_e2e_cross_signing_signatures_txn, - result, - from_user_id, - ) - - return result - - async def get_all_user_signature_changes_for_remotes( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for groups replication stream. - - Note that the user signature stream represents when a user signs their - device with their user-signing key, which is not published to other - users or servers, so no `destination` is needed in the returned - list. However, this is needed to poke workers. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def _get_all_user_signature_changes_for_remotes_txn(txn): - sql = """ - SELECT stream_id, from_user_id AS user_id - FROM user_signature_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - - updates = [(row[0], (row[1:])) for row in txn] - - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db.runInteraction( - "get_all_user_signature_changes_for_remotes", - _get_all_user_signature_changes_for_remotes_txn, - ) - - -class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): - """Stores device keys for a device. 
Returns whether there was a change - or the keys were already in the database. - """ - - def _set_e2e_device_keys_txn(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", device_keys) - - old_key_json = self.db.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False - - self.db.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True - - return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) - - def claim_e2e_one_time_keys(self, query_list): - """Take a list of one time keys out of the database""" - - @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" - ) - result = {} - delete = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" - ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." - } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - return result - - return self.db.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) - - def delete_e2e_keys_by_device(self, user_id, device_id): - def delete_e2e_keys_by_device_txn(txn): - log_kv( - { - "message": "Deleting keys for device", - "device_id": device_id, - "user_id": user_id, - } - ) - self.db.simple_delete_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self.db.simple_delete_txn( - txn, - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return self.db.runInteraction( - "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn - ) - - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): - """Set a user's cross-signing key. 
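# Aside (illustrative sketch): claim_e2e_one_time_keys (above) hands out at most
# one key per requested (user, device, algorithm) and deletes it so it can never
# be claimed twice.  A dict-backed model of that claim-and-delete step; the
# store layout and function name are invented for illustration.
def claim_one_time_keys(store, query_list):
    """store: {(user_id, device_id): {"alg:key_id": key_json, ...}}"""
    result = {}
    for user_id, device_id, algorithm in query_list:
        device_keys = store.get((user_id, device_id), {})
        for full_key_id in sorted(device_keys):
            if full_key_id.startswith(algorithm + ":"):
                key_json = device_keys.pop(full_key_id)  # claimed exactly once
                result.setdefault(user_id, {}).setdefault(device_id, {})[
                    full_key_id
                ] = key_json
                break
    return result


store = {("@bob:x", "DEV"): {"signed_curve25519:AAAA": "{}"}}
claimed = claim_one_time_keys(store, [("@bob:x", "DEV", "signed_curve25519")])
assert claimed == {"@bob:x": {"DEV": {"signed_curve25519:AAAA": "{}"}}}
# a second claim finds nothing left for that algorithm
assert claim_one_time_keys(store, [("@bob:x", "DEV", "signed_curve25519")]) == {}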
- - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user to set the signing key for - key_type (str): the type of key that is being set: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - key (dict): the key data - """ - # the 'key' dict will look something like: - # { - # "user_id": "@alice:example.com", - # "usage": ["self_signing"], - # "keys": { - # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", - # }, - # "signatures": { - # "@alice:example.com": { - # "ed25519:base64+master+public+key": "base64+signature" - # } - # } - # } - # The "keys" property must only have one entry, which will be the public - # key, so we just grab the first value in there - pubkey = next(iter(key["keys"].values())) - - # The cross-signing keys need to occupy the same namespace as devices, - # since signatures are identified by device ID. So add an entry to the - # device table to make sure that we don't have a collision with device - # IDs. - # We only need to do this for local users, since remote servers should be - # responsible for checking this for their own users. - if self.hs.is_mine_id(user_id): - self.db.simple_insert_txn( - txn, - "devices", - values={ - "user_id": user_id, - "device_id": pubkey, - "display_name": key_type + " signing key", - "hidden": True, - }, - ) - - # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json.dumps(key), - "stream_id": stream_id, - }, - ) - - self._invalidate_cache_and_stream( - txn, self._get_bare_e2e_cross_signing_keys, (user_id,) - ) - - def set_e2e_cross_signing_key(self, user_id, key_type, key): - """Set a user's cross-signing key. - - Args: - user_id (str): the user to set the user-signing key for - key_type (str): the type of cross-signing key to set - key (dict): the key data - """ - return self.db.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) - - def store_e2e_cross_signing_signatures(self, user_id, signatures): - """Stores cross-signing signatures. - - Args: - user_id (str): the user who made the signatures - signatures (iterable[SignatureListItem]): signatures to add - """ - return self.db.simple_insert_many( - "e2e_cross_signing_signatures", - [ - { - "user_id": user_id, - "key_id": item.signing_key_id, - "target_user_id": item.target_user_id, - "target_device_id": item.target_device_id, - "signature": item.signature, - } - for item in signatures - ], - "add_e2e_signing_key", - ) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py deleted file mode 100644 index a6bb3221ff..0000000000 --- a/synapse/storage/data_stores/main/event_federation.py +++ /dev/null @@ -1,724 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
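# Aside (illustrative sketch): _set_e2e_cross_signing_key_txn (above) pulls the
# single public key out of the uploaded key dict and reserves it as a hidden
# "device" so device IDs and cross-signing key IDs cannot collide.  The
# extraction step in isolation, using a made-up key:
key = {
    "user_id": "@alice:example.com",
    "usage": ["self_signing"],
    "keys": {"ed25519:base64+public+key": "base64+public+key"},
}
# "keys" holds exactly one entry; its value is the public key / device ID
pubkey = next(iter(key["keys"].values()))
assert pubkey == "base64+public+key"

# the reserved device row then looks roughly like this (hidden=True keeps it
# out of normal device listings)
reserved_device = {
    "user_id": key["user_id"],
    "device_id": pubkey,
    "display_name": "self_signing signing key",
    "hidden": True,
}
assert reserved_device["device_id"] == pubkey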
-# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -import logging -from queue import Empty, PriorityQueue -from typing import Dict, List, Optional, Set, Tuple - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.database import Database -from synapse.util.caches.descriptors import cached -from synapse.util.iterutils import batch_iter - -logger = logging.getLogger(__name__) - - -class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): - """Get auth events for given event_ids. The events *must* be state events. - - Args: - event_ids (list): state events - include_given (bool): include the given events in result - - Returns: - list of events - """ - return self.get_auth_chain_ids( - event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) - - def get_auth_chain_ids( - self, - event_ids: List[str], - include_given: bool = False, - ignore_events: Optional[Set[str]] = None, - ): - """Get auth events for given event_ids. The events *must* be state events. - - Args: - event_ids: state events - include_given: include the given events in result - ignore_events: Set of events to exclude from the returned auth - chain. This is useful if the caller will just discard the - given events anyway, and saves us from figuring out their auth - chains if not required. - - Returns: - list of event_ids - """ - return self.db.runInteraction( - "get_auth_chain_ids", - self._get_auth_chain_ids_txn, - event_ids, - include_given, - ignore_events, - ) - - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): - if ignore_events is None: - ignore_events = set() - - if include_given: - results = set(event_ids) - else: - results = set() - - base_sql = "SELECT auth_id FROM event_auth WHERE " - - front = set(event_ids) - while front: - new_front = set() - for chunk in batch_iter(front, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", chunk - ) - txn.execute(base_sql + clause, args) - new_front.update(r[0] for r in txn) - - new_front -= ignore_events - new_front -= results - - front = new_front - results.update(front) - - return list(results) - - def get_auth_chain_difference(self, state_sets: List[Set[str]]): - """Given sets of state events figure out the auth chain difference (as - per state res v2 algorithm). - - This equivalent to fetching the full auth chain for each set of state - and returning the events that don't appear in each and every auth - chain. - - Returns: - Deferred[Set[str]] - """ - - return self.db.runInteraction( - "get_auth_chain_difference", - self._get_auth_chain_difference_txn, - state_sets, - ) - - def _get_auth_chain_difference_txn( - self, txn, state_sets: List[Set[str]] - ) -> Set[str]: - - # Algorithm Description - # ~~~~~~~~~~~~~~~~~~~~~ - # - # The idea here is to basically walk the auth graph of each state set in - # tandem, keeping track of which auth events are reachable by each state - # set. 
If we reach an auth event we've already visited (via a different - # state set) then we mark that auth event and all ancestors as reachable - # by the state set. This requires that we keep track of the auth chains - # in memory. - # - # Doing it in a such a way means that we can stop early if all auth - # events we're currently walking are reachable by all state sets. - # - # *Note*: We can't stop walking an event's auth chain if it is reachable - # by all state sets. This is because other auth chains we're walking - # might be reachable only via the original auth chain. For example, - # given the following auth chain: - # - # A -> C -> D -> E - # / / - # B -´---------´ - # - # and state sets {A} and {B} then walking the auth chains of A and B - # would immediately show that C is reachable by both. However, if we - # stopped at C then we'd only reach E via the auth chain of B and so E - # would errornously get included in the returned difference. - # - # The other thing that we do is limit the number of auth chains we walk - # at once, due to practical limits (i.e. we can only query the database - # with a limited set of parameters). We pick the auth chains we walk - # each iteration based on their depth, in the hope that events with a - # lower depth are likely reachable by those with higher depths. - # - # We could use any ordering that we believe would give a rough - # topological ordering, e.g. origin server timestamp. If the ordering - # chosen is not topological then the algorithm still produces the right - # result, but perhaps a bit more inefficiently. This is why it is safe - # to use "depth" here. - - initial_events = set(state_sets[0]).union(*state_sets[1:]) - - # Dict from events in auth chains to which sets *cannot* reach them. - # I.e. if the set is empty then all sets can reach the event. - event_to_missing_sets = { - event_id: {i for i, a in enumerate(state_sets) if event_id not in a} - for event_id in initial_events - } - - # The sorted list of events whose auth chains we should walk. - search = [] # type: List[Tuple[int, str]] - - # We need to get the depth of the initial events for sorting purposes. - sql = """ - SELECT depth, event_id FROM events - WHERE %s - """ - # the list can be huge, so let's avoid looking them all up in one massive - # query. - for batch in batch_iter(initial_events, 1000): - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", batch - ) - txn.execute(sql % (clause,), args) - - # I think building a temporary list with fetchall is more efficient than - # just `search.extend(txn)`, but this is unconfirmed - search.extend(txn.fetchall()) - - # sort by depth - search.sort() - - # Map from event to its auth events - event_to_auth_events = {} # type: Dict[str, Set[str]] - - base_sql = """ - SELECT a.event_id, auth_id, depth - FROM event_auth AS a - INNER JOIN events AS e ON (e.event_id = a.auth_id) - WHERE - """ - - while search: - # Check whether all our current walks are reachable by all state - # sets. If so we can bail. 
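For reference, the set this walk produces matches the plain definition given in get_auth_chain_difference's docstring above: take the full auth chain of each state set and keep the events that do not appear in every one of those chains. A minimal in-memory sketch of that definition, assuming a hypothetical `auth_parents` map standing in for the event_auth table (the state events themselves are included in their own chains, mirroring how the initial events are seeded into `event_to_missing_sets` above):

    from typing import Dict, Iterable, List, Set

    def naive_auth_chain_difference(
        auth_parents: Dict[str, Set[str]], state_sets: List[Set[str]]
    ) -> Set[str]:
        """Events reachable from some, but not all, of the given state sets."""

        def chain(events: Iterable[str]) -> Set[str]:
            # Walk the auth DAG upwards, collecting the events and all of
            # their ancestors.
            seen = set()  # type: Set[str]
            stack = list(events)
            while stack:
                event_id = stack.pop()
                if event_id not in seen:
                    seen.add(event_id)
                    stack.extend(auth_parents.get(event_id, ()))
            return seen

        chains = [chain(state_set) for state_set in state_sets]
        return set.union(*chains) - set.intersection(*chains)

The code in this transaction computes the same set incrementally, so that it can batch its database queries and bail out early once every event still being walked is reachable by all state sets.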
- if all(not event_to_missing_sets[eid] for _, eid in search): - break - - # Fetch the auth events and their depths of the N last events we're - # currently walking - search, chunk = search[:-100], search[-100:] - clause, args = make_in_list_sql_clause( - txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] - ) - txn.execute(base_sql + clause, args) - - for event_id, auth_event_id, auth_event_depth in txn: - event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) - - sets = event_to_missing_sets.get(auth_event_id) - if sets is None: - # First time we're seeing this event, so we add it to the - # queue of things to fetch. - search.append((auth_event_depth, auth_event_id)) - - # Assume that this event is unreachable from any of the - # state sets until proven otherwise - sets = event_to_missing_sets[auth_event_id] = set( - range(len(state_sets)) - ) - else: - # We've previously seen this event, so look up its auth - # events and recursively mark all ancestors as reachable - # by the current event's state set. - a_ids = event_to_auth_events.get(auth_event_id) - while a_ids: - new_aids = set() - for a_id in a_ids: - event_to_missing_sets[a_id].intersection_update( - event_to_missing_sets[event_id] - ) - - b = event_to_auth_events.get(a_id) - if b: - new_aids.update(b) - - a_ids = new_aids - - # Mark that the auth event is reachable by the approriate sets. - sets.intersection_update(event_to_missing_sets[event_id]) - - search.sort() - - # Return all events where not all sets can reach them. - return {eid for eid, n in event_to_missing_sets.items() if n} - - def get_oldest_events_in_room(self, room_id): - return self.db.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, - room_id, - ) - - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" - " GROUP BY b.event_id" - ) - - txn.execute(sql, (room_id, False)) - - return dict(txn) - - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): - """Returns the max depth of a set of event IDs - - Args: - event_ids (list[str]) - - Returns - Deferred[int] - """ - rows = yield self.db.simple_select_many_batch( - table="events", - column="event_id", - iterable=event_ids, - retcols=("depth",), - desc="get_max_depth_of", - ) - - if not rows: - return 0 - else: - return max(row["depth"] for row in rows) - - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db.simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - - def get_prev_events_for_room(self, room_id: str): - """ - Gets a subset of the current forward extremities in the given room. - - Limits the result to 10 extremities, so that we can avoid creating - events which refer to hundreds of prev_events. 
- - Args: - room_id (str): room_id - - Returns: - Deferred[List[str]]: the event ids of the forward extremites - - """ - - return self.db.runInteraction( - "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id - ) - - def _get_prev_events_for_room_txn(self, txn, room_id: str): - # we just use the 10 newest events. Older events will become - # prev_events of future events. - - sql = """ - SELECT e.event_id FROM event_forward_extremities AS f - INNER JOIN events AS e USING (event_id) - WHERE f.room_id = ? - ORDER BY e.depth DESC - LIMIT 10 - """ - - txn.execute(sql, (room_id,)) - - return [row[0] for row in txn] - - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): - """Get the top rooms with at least N extremities. - - Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results - - Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. - """ - - def _get_rooms_with_many_extremities_txn(txn): - where_clause = "1=1" - if room_id_filter: - where_clause = "room_id NOT IN (%s)" % ( - ",".join("?" for _ in room_id_filter), - ) - - sql = """ - SELECT room_id FROM event_forward_extremities - WHERE %s - GROUP BY room_id - HAVING count(*) > ? - ORDER BY count(*) DESC - LIMIT ? - """ % ( - where_clause, - ) - - query_args = list(itertools.chain(room_id_filter, [min_count, limit])) - txn.execute(sql, query_args) - return [room_id for room_id, in txn] - - return self.db.runInteraction( - "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn - ) - - @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self.db.simple_select_onecol( - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - desc="get_latest_event_ids_in_room", - ) - - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. - """ - return self.db.runInteraction( - "get_min_depth", self._get_min_depth_interaction, room_id - ) - - def _get_min_depth_interaction(self, txn, room_id): - min_depth = self.db.simple_select_one_onecol_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - retcol="min_depth", - allow_none=True, - ) - - return int(min_depth) if min_depth is not None else None - - def get_forward_extremeties_for_room(self, room_id, stream_ordering): - """For a given room_id and stream_ordering, return the forward - extremeties of the room at that point in "time". - - Throws a StoreError if we have since purged the index for - stream_orderings from that point. - - Args: - room_id (str): - stream_ordering (int): - - Returns: - deferred, which resolves to a list of event_ids - """ - # We want to make the cache more effective, so we clamp to the last - # change before the given ordering. - last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) - - # We don't always have a full stream_to_exterm_id table, e.g. after - # the upgrade that introduced it, so we make sure we never ask for a - # stream_ordering from before a restart - last_change = max(self._stream_order_on_start, last_change) - - # provided the last_change is recent enough, we now clamp the requested - # stream_ordering to it. 
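The clamping described above is purely a cache optimisation: any requested position between the room's last change and "now" resolves to the same forward extremities, so collapsing those positions onto one value lets the @cached lookup below serve them all from a single entry. A small sketch of the idea, with hypothetical names (`oldest_indexed_pos` plays the role of `stream_ordering_month_ago`, beyond which stream_ordering_to_exterm entries have been purged):

    def clamp_for_extremity_cache(
        requested_pos: int, last_change_pos: int, oldest_indexed_pos: int
    ) -> int:
        # Positions at or after the room's last change all yield the same
        # extremities, so rewrite them down to that last change; earlier
        # positions are left untouched by min(). Only do this while the
        # index still reaches back that far.
        if last_change_pos > oldest_indexed_pos:
            return min(last_change_pos, requested_pos)
        return requested_pos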
- if last_change > self.stream_ordering_month_ago: - stream_ordering = min(last_change, stream_ordering) - - return self._get_forward_extremeties_for_room(room_id, stream_ordering) - - @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): - """For a given room_id and stream_ordering, return the forward - extremeties of the room at that point in "time". - - Throws a StoreError if we have since purged the index for - stream_orderings from that point. - """ - - if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") - - sql = """ - SELECT event_id FROM stream_ordering_to_exterm - INNER JOIN ( - SELECT room_id, MAX(stream_ordering) AS stream_ordering - FROM stream_ordering_to_exterm - WHERE stream_ordering <= ? GROUP BY room_id - ) AS rms USING (room_id, stream_ordering) - WHERE room_id = ? - """ - - def get_forward_extremeties_for_room_txn(txn): - txn.execute(sql, (stream_ordering, room_id)) - return [event_id for event_id, in txn] - - return self.db.runInteraction( - "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn - ) - - def get_backfill_events(self, room_id, event_list, limit): - """Get a list of Events for a given topic that occurred before (and - including) the events in event_list. Return a list of max size `limit` - - Args: - txn - room_id (str) - event_list (list) - limit (int) - """ - return ( - self.db.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) - ) - - def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) - - event_results = set() - - # We want to make sure that we do a breadth-first, "depth" ordered - # search. - - query = ( - "SELECT depth, prev_event_id FROM event_edges" - " INNER JOIN events" - " ON prev_event_id = events.event_id" - " WHERE event_edges.event_id = ?" - " AND event_edges.is_state = ?" - " LIMIT ?" - ) - - queue = PriorityQueue() - - for event_id in event_list: - depth = self.db.simple_select_one_onecol_txn( - txn, - table="events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcol="depth", - allow_none=True, - ) - - if depth: - queue.put((-depth, event_id)) - - while not queue.empty() and len(event_results) < limit: - try: - _, event_id = queue.get_nowait() - except Empty: - break - - if event_id in event_results: - continue - - event_results.add(event_id) - - txn.execute(query, (event_id, False, limit - len(event_results))) - - for row in txn: - if row[1] not in event_results: - queue.put((-row[0], row[1])) - - return event_results - - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.db.runInteraction( - "get_missing_events", - self._get_missing_events, - room_id, - earliest_events, - latest_events, - limit, - ) - events = yield self.get_events_as_list(ids) - return events - - def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): - - seen_events = set(earliest_events) - front = set(latest_events) - seen_events - event_results = [] - - query = ( - "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = ? " - "LIMIT ?" 
- ) - - while front and len(event_results) < limit: - new_front = set() - for event_id in front: - txn.execute( - query, (room_id, event_id, False, limit - len(event_results)) - ) - - new_results = {t[0] for t in txn} - seen_events - - new_front |= new_results - seen_events |= new_results - event_results.extend(new_results) - - front = new_front - - # we built the list working backwards from latest_events; we now need to - # reverse it so that the events are approximately chronological. - event_results.reverse() - return event_results - - @defer.inlineCallbacks - def get_successor_events(self, event_ids): - """Fetch all events that have the given events as a prev event - - Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] - """ - rows = yield self.db.simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=event_ids, - retcols=("event_id",), - desc="get_successor_events", - ) - - return [row["event_id"] for row in rows] - - -class EventFederationStore(EventFederationWorkerStore): - """ Responsible for storing and serving up the various graphs associated - with an event. Including the main event graph and the auth chains for an - event. - - Also has methods for getting the front (latest) and back (oldest) edges - of the event graphs. These are used to generate the parents for new events - and backfilling from another server respectively. - """ - - EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - - def __init__(self, database: Database, db_conn, hs): - super(EventFederationStore, self).__init__(database, db_conn, hs) - - self.db.updates.register_background_update_handler( - self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth - ) - - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) - - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = """ - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? - """ - txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) - ) - - return run_as_background_process( - "delete_old_forward_extrem_cache", - self.db.runInteraction, - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn, - ) - - def clean_room_for_join(self, room_id): - return self.db.runInteraction( - "clean_room_for_join", self._clean_room_for_join_txn, room_id - ) - - def _clean_room_for_join_txn(self, txn, room_id): - query = "DELETE FROM event_forward_extremities WHERE room_id = ?" 
- - txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): - def delete_event_auth(txn): - target_min_stream_id = progress.get("target_min_stream_id_inclusive") - max_stream_id = progress.get("max_stream_id_exclusive") - - if not target_min_stream_id or not max_stream_id: - txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") - rows = txn.fetchall() - target_min_stream_id = rows[0][0] - - txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") - rows = txn.fetchall() - max_stream_id = rows[0][0] - - min_stream_id = max_stream_id - batch_size - - sql = """ - DELETE FROM event_auth - WHERE event_id IN ( - SELECT event_id FROM events - LEFT JOIN state_events USING (room_id, event_id) - WHERE ? <= stream_ordering AND stream_ordering < ? - AND state_key IS null - ) - """ - - txn.execute(sql, (min_stream_id, max_stream_id)) - - new_progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - } - - self.db.updates._background_update_progress_txn( - txn, self.EVENT_AUTH_STATE_ONLY, new_progress - ) - - return min_stream_id >= target_min_stream_id - - result = yield self.db.runInteraction( - self.EVENT_AUTH_STATE_ONLY, delete_event_auth - ) - - if not result: - yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) - - return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py deleted file mode 100644 index ad82838901..0000000000 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ /dev/null @@ -1,883 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import List - -from canonicaljson import json - -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json -from synapse.storage.database import Database -from synapse.util.caches.descriptors import cachedInlineCallbacks - -logger = logging.getLogger(__name__) - - -DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] -DEFAULT_HIGHLIGHT_ACTION = [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, -] - - -def _serialize_action(actions, is_highlight): - """Custom serializer for actions. This allows us to "compress" common actions. - - We use the fact that most users have the same actions for notifs (and for - highlights). - We store these default actions as the empty string rather than the full JSON. 
- Since the empty string isn't valid JSON there is no risk of this clashing with - any real JSON actions - """ - if is_highlight: - if actions == DEFAULT_HIGHLIGHT_ACTION: - return "" # We use empty string as the column is non-NULL - else: - if actions == DEFAULT_NOTIF_ACTION: - return "" - return json.dumps(actions) - - -def _deserialize_action(actions, is_highlight): - """Custom deserializer for actions. This allows us to "compress" common actions - """ - if actions: - return db_to_json(actions) - - if is_highlight: - return DEFAULT_HIGHLIGHT_ACTION - else: - return DEFAULT_NOTIF_ACTION - - -class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) - - # These get correctly set by _find_stream_orderings_for_times_txn - self.stream_ordering_month_ago = None - self.stream_ordering_day_ago = None - - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) - self._find_stream_orderings_for_times_txn(cur) - cur.close() - - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 10 * 60 * 1000 - ) - self._rotate_delay = 3 - self._rotate_count = 10000 - - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): - ret = yield self.db.runInteraction( - "get_unread_event_push_actions_by_room", - self._get_unread_counts_by_receipt_txn, - room_id, - user_id, - last_read_event_id, - ) - return ret - - def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id - ): - sql = ( - "SELECT stream_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute(sql, (room_id, last_read_event_id)) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} - - stream_ordering = results[0][0] - - return self._get_unread_counts_by_pos_txn( - txn, room_id, user_id, stream_ordering - ) - - def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - notify_count = row[0] if row else 0 - - txn.execute( - """ - SELECT notif_count FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering > ? - """, - (room_id, user_id, stream_ordering), - ) - rows = txn.fetchall() - if rows: - notify_count += rows[0][0] - - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 - - return {"notify_count": notify_count, "highlight_count": highlight_count} - - async def get_push_action_users_in_range( - self, min_stream_ordering, max_stream_ordering - ): - def f(txn): - sql = ( - "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" - " stream_ordering >= ? AND stream_ordering <= ?" 
- ) - txn.execute(sql, (min_stream_ordering, max_stream_ordering)) - return [r[0] for r in txn] - - ret = await self.db.runInteraction("get_push_action_users_in_range", f) - return ret - - async def get_unread_push_actions_for_user_in_range_for_http( - self, - user_id: str, - min_stream_ordering: int, - max_stream_ordering: int, - limit: int = 20, - ) -> List[dict]: - """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. Called by the httppusher. - - Args: - user_id: The user to fetch push actions for. - min_stream_ordering: The exclusive lower bound on the - stream ordering of event push actions to fetch. - max_stream_ordering: The inclusive upper bound on the - stream ordering of event push actions to fetch. - limit: The maximum number of rows to return. - Returns: - A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". - The list will be ordered by ascending stream_ordering. - The list will have between 0~limit entries. - """ - # find rooms that have a read receipt in them and return the next - # push actions - def get_after_receipt(txn): - # find rooms that have a read receipt in them and return the next - # push actions - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - after_read_receipt = await self.db.runInteraction( - "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt - ) - - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt(txn): - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - no_read_receipt = await self.db.runInteraction( - "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt - ) - - notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - } - for row in after_read_receipt + no_read_receipt - ] - - # Now sort it so it's ordered correctly, since currently it will - # contain results from the first query, correctly ordered, followed - # by results from the second query, but we want them all ordered - # by stream_ordering, oldest first. 
- notifs.sort(key=lambda r: r["stream_ordering"]) - - # Take only up to the limit. We have to stop at the limit because - # one of the subqueries may have hit the limit. - return notifs[:limit] - - async def get_unread_push_actions_for_user_in_range_for_email( - self, - user_id: str, - min_stream_ordering: int, - max_stream_ordering: int, - limit: int = 20, - ) -> List[dict]: - """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. Called by the emailpusher - - Args: - user_id: The user to fetch push actions for. - min_stream_ordering: The exclusive lower bound on the - stream ordering of event push actions to fetch. - max_stream_ordering: The inclusive upper bound on the - stream ordering of event push actions to fetch. - limit: The maximum number of rows to return. - Returns: - A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". - The list will be ordered by descending received_ts. - The list will have between 0~limit entries. - """ - # find rooms that have a read receipt in them and return the most recent - # push actions - def get_after_receipt(txn): - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - after_read_receipt = await self.db.runInteraction( - "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt - ) - - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt(txn): - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - no_read_receipt = await self.db.runInteraction( - "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt - ) - - # Make a list of dicts from the two sets of results. 
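The `actions` values decoded in the list below use the compression scheme implemented by `_serialize_action` and `_deserialize_action` at the top of this module: the two stock action lists are stored as an empty string (with the highlight flag kept in its own column, row[4] here), and only unusual action lists are stored as real JSON. If run inside this module, the pair round-trips like this (an illustrative check, not part of the original code):

    # The default notify/highlight actions collapse to "" on the way in...
    assert _serialize_action(DEFAULT_NOTIF_ACTION, False) == ""
    assert _serialize_action(DEFAULT_HIGHLIGHT_ACTION, True) == ""

    # ...and are reconstructed from the stored highlight flag on the way out.
    assert _deserialize_action("", False) == DEFAULT_NOTIF_ACTION
    assert _deserialize_action("", True) == DEFAULT_HIGHLIGHT_ACTION

    # Anything else is stored as, and recovered from, its JSON encoding.
    custom = ["notify", {"set_tweak": "sound", "value": "ping"}]
    assert _deserialize_action(_serialize_action(custom, False), False) == custom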
- notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - "received_ts": row[5], - } - for row in after_read_receipt + no_read_receipt - ] - - # Now sort it so it's ordered correctly, since currently it will - # contain results from the first query, correctly ordered, followed - # by results from the second query, but we want them all ordered - # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r["received_ts"] or 0)) - - # Now return the first `limit` - return notifs[:limit] - - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): - """A fast check to see if there might be something to push for the - user since the given stream ordering. May return false positives. - - Useful to know whether to bother starting a pusher on start up or not. - - Args: - user_id (str) - min_stream_ordering (int) - - Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. - """ - - def _get_if_maybe_push_in_range_for_user_txn(txn): - sql = """ - SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? - LIMIT 1 - """ - - txn.execute(sql, (user_id, min_stream_ordering)) - return bool(txn.fetchone()) - - return self.db.runInteraction( - "get_if_maybe_push_in_range_for_user", - _get_if_maybe_push_in_range_for_user_txn, - ) - - async def add_push_actions_to_staging(self, event_id, user_id_actions): - """Add the push actions for the event to the push action staging area. - - Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred - """ - - if not user_id_actions: - return - - # This is a helper function for generating the necessary tuple that - # can be used to inert into the `event_push_actions_staging` table. - def _gen_entry(user_id, actions): - is_highlight = 1 if _action_has_highlight(actions) else 0 - return ( - event_id, # event_id column - user_id, # user_id column - _serialize_action(actions, is_highlight), # actions column - 1, # notif column - is_highlight, # highlight column - ) - - def _add_push_actions_to_staging_txn(txn): - # We don't use simple_insert_many here to avoid the overhead - # of generating lists of dicts. - - sql = """ - INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight) - VALUES (?, ?, ?, ?, ?) - """ - - txn.executemany( - sql, - ( - _gen_entry(user_id, actions) - for user_id, actions in user_id_actions.items() - ), - ) - - return await self.db.runInteraction( - "add_push_actions_to_staging", _add_push_actions_to_staging_txn - ) - - async def remove_push_actions_from_staging(self, event_id: str) -> None: - """Called if we failed to persist the event to ensure that stale push - actions don't build up in the DB - """ - - try: - res = await self.db.simple_delete( - table="event_push_actions_staging", - keyvalues={"event_id": event_id}, - desc="remove_push_actions_from_staging", - ) - return res - except Exception: - # this method is called from an exception handler, so propagating - # another exception here really isn't helpful - there's nothing - # the caller can do about it. Just log the exception and move on. 
- logger.exception( - "Error removing push actions after event persistence failure" - ) - - def _find_stream_orderings_for_times(self): - return run_as_background_process( - "event_push_action_stream_orderings", - self.db.runInteraction, - "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn, - ) - - def _find_stream_orderings_for_times_txn(self, txn): - logger.info("Searching for stream ordering 1 month ago") - self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago - ) - logger.info("Searching for stream ordering 1 day ago") - self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago - ) - - def find_first_stream_ordering_after_ts(self, ts): - """Gets the stream ordering corresponding to a given timestamp. - - Specifically, finds the stream_ordering of the first event that was - received on or after the timestamp. This is done by a binary search on - the events table, since there is no index on received_ts, so is - relatively slow. - - Args: - ts (int): timestamp in millis - - Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp - """ - return self.db.runInteraction( - "_find_first_stream_ordering_after_ts_txn", - self._find_first_stream_ordering_after_ts_txn, - ts, - ) - - @staticmethod - def _find_first_stream_ordering_after_ts_txn(txn, ts): - """ - Find the stream_ordering of the first event that was received on or - after a given timestamp. This is relatively slow as there is no index - on received_ts but we can then use this to delete push actions before - this. - - received_ts must necessarily be in the same order as stream_ordering - and stream_ordering is indexed, so we manually binary search using - stream_ordering - - Args: - txn (twisted.enterprise.adbapi.Transaction): - ts (int): timestamp to search for - - Returns: - int: stream ordering - """ - txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] - - if max_stream_ordering is None: - return 0 - - # We want the first stream_ordering in which received_ts is greater - # than or equal to ts. Call this point X. - # - # We maintain the invariants: - # - # range_start <= X <= range_end - # - range_start = 0 - range_end = max_stream_ordering + 1 - - # Given a stream_ordering, look up the timestamp at that - # stream_ordering. - # - # The array may be sparse (we may be missing some stream_orderings). - # We treat the gaps as the same as having the same value as the - # preceding entry, because we will pick the lowest stream_ordering - # which satisfies our requirement of received_ts >= ts. - # - # For example, if our array of events indexed by stream_ordering is - # [10, , 20], we should treat this as being equivalent to - # [10, 10, 20]. - # - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering <= ?" 
- " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) - - while range_end - range_start > 0: - middle = (range_end + range_start) // 2 - txn.execute(sql, (middle,)) - row = txn.fetchone() - if row is None: - # no rows with stream_ordering<=middle - range_start = middle + 1 - continue - - middle_ts = row[0] - if ts > middle_ts: - # we got a timestamp lower than the one we were looking for. - # definitely need to look higher: X > middle. - range_start = middle + 1 - else: - # we got a timestamp higher than (or the same as) the one we - # were looking for. We aren't yet sure about the point we - # looked up, but we can be sure that X <= middle. - range_end = middle - - return range_end - - async def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ?" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) - txn.execute(sql, (stream_ordering,)) - return txn.fetchone() - - result = await self.db.runInteraction("get_time_of_last_push_action_before", f) - return result[0] if result else None - - -class EventPushActionsStore(EventPushActionsWorkerStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__(self, database: Database, db_conn, hs): - super(EventPushActionsStore, self).__init__(database, db_conn, hs) - - self.db.updates.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.db.updates.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1", - ) - - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._start_rotate_notifs, 30 * 60 * 1000 - ) - - async def get_push_actions_for_user( - self, user_id, before=None, limit=50, only_highlight=False - ): - def f(txn): - before_clause = "" - if before: - before_clause = "AND epa.stream_ordering < ?" - args = [user_id, before, limit] - else: - args = [user_id, limit] - - if only_highlight: - if len(before_clause) > 0: - before_clause += " " - before_clause += "AND epa.highlight = 1" - - # NB. This assumes event_ids are globally unique since - # it makes the query easier to index - sql = ( - "SELECT epa.event_id, epa.room_id," - " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" - " FROM event_push_actions epa, events e" - " WHERE epa.event_id = e.event_id" - " AND epa.user_id = ? %s" - " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" 
% (before_clause,) - ) - txn.execute(sql, args) - return self.db.cursor_to_dict(txn) - - push_actions = await self.db.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions - - async def get_latest_push_action_stream_ordering(self): - def f(txn): - txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") - return txn.fetchone() - - result = await self.db.runInteraction( - "get_latest_push_action_stream_ordering", f - ) - return result[0] or 0 - - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), - ) - - def _start_rotate_notifs(self): - return run_as_background_process("rotate_notifs", self._rotate_notifs) - - async def _rotate_notifs(self): - if self._doing_notif_rotation or self.stream_ordering_day_ago is None: - return - self._doing_notif_rotation = True - - try: - while True: - logger.info("Rotating notifications") - - caught_up = await self.db.runInteraction( - "_rotate_notifs", self._rotate_notifs_txn - ) - if caught_up: - break - await self.hs.get_clock().sleep(self._rotate_delay) - finally: - self._doing_notif_rotation = False - - def _rotate_notifs_txn(self, txn): - """Archives older notifications into event_push_summary. Returns whether - the archiving process has caught up or not. - """ - - old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) - - # We don't to try and rotate millions of rows at once, so we cap the - # maximum stream ordering we'll rotate before. - txn.execute( - """ - SELECT stream_ordering FROM event_push_actions - WHERE stream_ordering > ? - ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? 
- """, - (old_rotate_stream_ordering, self._rotate_count), - ) - stream_row = txn.fetchone() - if stream_row: - (offset_stream_ordering,) = stream_row - rotate_to_stream_ordering = min( - self.stream_ordering_day_ago, offset_stream_ordering - ) - caught_up = offset_stream_ordering >= self.stream_ordering_day_ago - else: - rotate_to_stream_ordering = self.stream_ordering_day_ago - caught_up = True - - logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) - - self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) - - # We have caught up iff we were limited by `stream_ordering_day_ago` - return caught_up - - def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) - - # Calculate the new counts that should be upserted into event_push_summary - sql = """ - SELECT user_id, room_id, - coalesce(old.notif_count, 0) + upd.notif_count, - upd.stream_ordering, - old.user_id - FROM ( - SELECT user_id, room_id, count(*) as notif_count, - max(stream_ordering) as stream_ordering - FROM event_push_actions - WHERE ? <= stream_ordering AND stream_ordering < ? - AND highlight = 0 - GROUP BY user_id, room_id - ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) - """ - - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) - rows = txn.fetchall() - - logger.info("Rotating notifications, handling %d rows", len(rows)) - - # If the `old.user_id` above is NULL then we know there isn't already an - # entry in the table, so we simply insert it. Otherwise we update the - # existing table. - self.db.simple_insert_many_txn( - txn, - table="event_push_summary", - values=[ - { - "user_id": row[0], - "room_id": row[1], - "notif_count": row[2], - "stream_ordering": row[3], - } - for row in rows - if row[4] is None - ], - ) - - txn.executemany( - """ - UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? - WHERE user_id = ? AND room_id = ? - """, - ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), - ) - - txn.execute( - "DELETE FROM event_push_actions" - " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) - - txn.execute( - "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", - (rotate_to_stream_ordering,), - ) - - -def _action_has_highlight(actions): - for action in actions: - try: - if action.get("set_tweak", None) == "highlight": - return action.get("value", True) - except AttributeError: - pass - - return False diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py deleted file mode 100644 index 0c9c02afa1..0000000000 --- a/synapse/storage/data_stores/main/events.py +++ /dev/null @@ -1,1521 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -import logging -from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple - -import attr -from prometheus_client import Counter - -from twisted.internet import defer - -import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes -from synapse.api.room_versions import RoomVersions -from synapse.crypto.event_signing import compute_event_reference_hash -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.logging.utils import log_function -from synapse.storage._base import db_to_json, make_in_list_sql_clause -from synapse.storage.data_stores.main.search import SearchEntry -from synapse.storage.database import Database, LoggingTransaction -from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import StateMap, get_domain_from_id -from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.iterutils import batch_iter - -if TYPE_CHECKING: - from synapse.server import HomeServer - from synapse.storage.data_stores.main import DataStore - - -logger = logging.getLogger(__name__) - -persist_event_counter = Counter("synapse_storage_events_persisted_events", "") -event_counter = Counter( - "synapse_storage_events_persisted_events_sep", - "", - ["type", "origin_type", "origin_entity"], -) - -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - - -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - -@attr.s(slots=True) -class DeltaState: - """Deltas to use to update the `current_state_events` table. 
- - Attributes: - to_delete: List of type/state_keys to delete from current state - to_insert: Map of state to upsert into current state - no_longer_in_room: The server is not longer in the room, so the room - should e.g. be removed from `current_state_events` table. - """ - - to_delete = attr.ib(type=List[Tuple[str, str]]) - to_insert = attr.ib(type=StateMap[str]) - no_longer_in_room = attr.ib(type=bool, default=False) - - -class PersistEventsStore: - """Contains all the functions for writing events to the database. - - Should only be instantiated on one process (when using a worker mode setup). - - Note: This is not part of the `DataStore` mixin. - """ - - def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): - self.hs = hs - self.db = db - self.store = main_data_store - self.database_engine = db.engine - self._clock = hs.get_clock() - - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - self.is_mine_id = hs.is_mine_id - - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator - - # This should only exist on instances that are configured to write - assert ( - hs.config.worker.writers.events == hs.get_instance_name() - ), "Can only instantiate EventsStore on master" - - @defer.inlineCallbacks - def _persist_events_and_state_updates( - self, - events_and_contexts: List[Tuple[EventBase, EventContext]], - current_state_for_room: Dict[str, StateMap[str]], - state_delta_for_room: Dict[str, DeltaState], - new_forward_extremeties: Dict[str, List[str]], - backfilled: bool = False, - ): - """Persist a set of events alongside updates to the current state and - forward extremities tables. - - Args: - events_and_contexts: - current_state_for_room: Map from room_id to the current state of - the room based on forward extremities - state_delta_for_room: Map from room_id to the delta to apply to - room state - new_forward_extremities: Map from room_id to list of event IDs - that are the new forward extremities of the room. - backfilled - - Returns: - Deferred: resolves when the events have been persisted - """ - - # We want to calculate the stream orderings as late as possible, as - # we only notify after all events with a lesser stream ordering have - # been persisted. I.e. if we spend 10s inside the with block then - # that will delay all subsequent events from being notified about. - # Hence why we do it down here rather than wrapping the entire - # function. - # - # Its safe to do this after calculating the state deltas etc as we - # only need to protect the *persistence* of the events. This is to - # ensure that queries of the form "fetch events since X" don't - # return events and stream positions after events that are still in - # flight, as otherwise subsequent requests "fetch event since Y" - # will not return those events. - # - # Note: Multiple instances of this function cannot be in flight at - # the same time for the same room. 
- if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(events_and_contexts) - ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( - len(events_and_contexts) - ) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(events_and_contexts, stream_orderings): - event.internal_metadata.stream_ordering = stream - - yield self.db.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=events_and_contexts, - backfilled=backfilled, - state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, - ) - persist_event_counter.inc(len(events_and_contexts)) - - if not backfilled: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - events_and_contexts[-1][0].internal_metadata.stream_ordering - ) - - for event, context in events_and_contexts: - if context.app_service: - origin_type = "local" - origin_entity = context.app_service.id - elif self.hs.is_mine_id(event.sender): - origin_type = "local" - origin_entity = "*client*" - else: - origin_type = "remote" - origin_entity = get_domain_from_id(event.sender) - - event_counter.labels(event.type, origin_type, origin_entity).inc() - - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - - for room_id, new_state in current_state_for_room.items(): - self.store.get_current_state_ids.prefill((room_id,), new_state) - - for room_id, latest_event_ids in new_forward_extremeties.items(): - self.store.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) - ) - - @defer.inlineCallbacks - def _get_events_which_are_prevs(self, event_ids): - """Filter the supplied list of event_ids to get those which are prev_events of - existing (non-outlier/rejected) events. - - Args: - event_ids (Iterable[str]): event ids to filter - - Returns: - Deferred[List[str]]: filtered event ids - """ - results = [] - - def _get_events_which_are_prevs_txn(txn, batch): - sql = """ - SELECT prev_event_id, internal_metadata - FROM event_edges - INNER JOIN events USING (event_id) - LEFT JOIN rejections USING (event_id) - LEFT JOIN event_json USING (event_id) - WHERE - NOT events.outlier - AND rejections.event_id IS NULL - AND - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "prev_event_id", batch - ) - - txn.execute(sql + clause, args) - results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) - - for chunk in batch_iter(event_ids, 100): - yield self.db.runInteraction( - "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk - ) - - return results - - @defer.inlineCallbacks - def _get_prevs_before_rejected(self, event_ids): - """Get soft-failed ancestors to remove from the extremities. - - Given a set of events, find all those that have been soft-failed or - rejected. Returns those soft failed/rejected events and their prev - events (whether soft-failed/rejected or not), and recurses up the - prev-event graph until it finds no more soft-failed/rejected events. - - This is used to find extremities that are ancestors of new events, but - are separated by soft failed events. - - Args: - event_ids (Iterable[str]): Events to find prev events for. Note - that these must have already been persisted. - - Returns: - Deferred[set[str]] - """ - - # The set of event_ids to return. 
This includes all soft-failed events - # and their prev events. - existing_prevs = set() - - def _get_prevs_before_rejected_txn(txn, batch): - to_recursively_check = batch - - while to_recursively_check: - sql = """ - SELECT - event_id, prev_event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - LEFT JOIN rejections USING (event_id) - LEFT JOIN event_json USING (event_id) - WHERE - NOT events.outlier - AND - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", to_recursively_check - ) - - txn.execute(sql + clause, args) - to_recursively_check = [] - - for event_id, prev_event_id, metadata, rejected in txn: - if prev_event_id in existing_prevs: - continue - - soft_failed = db_to_json(metadata).get("soft_failed") - if soft_failed or rejected: - to_recursively_check.append(prev_event_id) - existing_prevs.add(prev_event_id) - - for chunk in batch_iter(event_ids, 100): - yield self.db.runInteraction( - "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk - ) - - return existing_prevs - - @log_function - def _persist_events_txn( - self, - txn: LoggingTransaction, - events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool, - state_delta_for_room: Dict[str, DeltaState] = {}, - new_forward_extremeties: Dict[str, List[str]] = {}, - ): - """Insert some number of room events into the necessary database tables. - - Rejected events are only inserted into the events table, the events_json table, - and the rejections table. Things reading from those table will need to check - whether the event was rejected. - - Args: - txn - events_and_contexts: events to persist - backfilled: True if the events were backfilled - delete_existing True to purge existing table rows for the events - from the database. This is useful when retrying due to - IntegrityError. - state_delta_for_room: The current-state delta for each room. - new_forward_extremetie: The new forward extremities for each room. - For each room, a list of the event ids which are the forward - extremities. - - """ - all_events_and_contexts = events_and_contexts - - min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - - self._update_forward_extremities_txn( - txn, - new_forward_extremities=new_forward_extremeties, - max_stream_order=max_stream_order, - ) - - # Ensure that we don't have the same event twice. - events_and_contexts = self._filter_events_and_contexts_for_duplicates( - events_and_contexts - ) - - self._update_room_depths_txn( - txn, events_and_contexts=events_and_contexts, backfilled=backfilled - ) - - # _update_outliers_txn filters out any events which have already been - # persisted, and returns the filtered list. - events_and_contexts = self._update_outliers_txn( - txn, events_and_contexts=events_and_contexts - ) - - # From this point onwards the events are only events that we haven't - # seen before. - - self._store_event_txn(txn, events_and_contexts=events_and_contexts) - - # Insert into event_to_state_groups. - self._store_event_state_mappings_txn(txn, events_and_contexts) - - # We want to store event_auth mappings for rejected events, as they're - # used in state res v2. - # This is only necessary if the rejected event appears in an accepted - # event's auth chain, but its easier for now just to store them (and - # it doesn't take much storage compared to storing the entire event - # anyway). 
- self.db.simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id in event.auth_event_ids() - if event.is_state() - ], - ) - - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts - ) - - # From this point onwards the events are only ones that weren't - # rejected. - - self._update_metadata_tables_txn( - txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, - ) - - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) - - def _update_current_state_txn( - self, - txn: LoggingTransaction, - state_delta_by_room: Dict[str, DeltaState], - stream_id: int, - ): - for room_id, delta_state in state_delta_by_room.items(): - to_delete = delta_state.to_delete - to_insert = delta_state.to_insert - - if delta_state.no_longer_in_room: - # Server is no longer in the room so we delete the room from - # current_state_events, being careful we've already updated the - # rooms.room_version column (which gets populated in a - # background task). - self._upsert_room_version_txn(txn, room_id) - - # Before deleting we populate the current_state_delta_stream - # so that async background tasks get told what happened. - sql = """ - INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, room_id, type, state_key, null, event_id - FROM current_state_events - WHERE room_id = ? - """ - txn.execute(sql, (stream_id, room_id)) - - self.db.simple_delete_txn( - txn, table="current_state_events", keyvalues={"room_id": room_id}, - ) - else: - # We're still in the room, so we update the current state as normal. - - # First we add entries to the current_state_delta_stream. We - # do this before updating the current_state_events table so - # that we can use it to calculate the `prev_event_id`. (This - # allows us to not have to pull out the existing state - # unnecessarily). - # - # The stream_id for the update is chosen to be the minimum of the stream_ids - # for the batch of the events that we are persisting; that means we do not - # end up in a situation where workers see events before the - # current_state_delta updates. - # - sql = """ - INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, ?, ?, ?, ?, ( - SELECT event_id FROM current_state_events - WHERE room_id = ? AND type = ? AND state_key = ? - ) - """ - txn.executemany( - sql, - ( - ( - stream_id, - room_id, - etype, - state_key, - to_insert.get((etype, state_key)), - room_id, - etype, - state_key, - ) - for etype, state_key in itertools.chain(to_delete, to_insert) - ), - ) - # Now we actually update the current_state_events table - - txn.executemany( - "DELETE FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?", - ( - (room_id, etype, state_key) - for etype, state_key in itertools.chain(to_delete, to_insert) - ), - ) - - # We include the membership in the current state table, hence we do - # a lookup when we insert. This assumes that all events have already - # been inserted into room_memberships. 
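
# A small sqlite3 sketch of the "insert with a correlated sub-select" trick used just
# below: rather than reading room_memberships first and inserting afterwards, the INSERT
# itself looks the membership up per row. Toy schema and values, for illustration only.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE room_memberships (event_id TEXT PRIMARY KEY, membership TEXT);
    CREATE TABLE current_state_events (
        room_id TEXT, type TEXT, state_key TEXT, event_id TEXT, membership TEXT
    );
    INSERT INTO room_memberships VALUES ('$join', 'join'), ('$leave', 'leave');
    """
)

to_insert = {
    ("m.room.member", "@alice:example.com"): "$join",
    ("m.room.member", "@bob:example.com"): "$leave",
}

conn.executemany(
    """
    INSERT INTO current_state_events (room_id, type, state_key, event_id, membership)
    VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
    """,
    [("!room:example.com", key[0], key[1], ev_id, ev_id) for key, ev_id in to_insert.items()],
)

print(conn.execute("SELECT state_key, membership FROM current_state_events").fetchall())
# [('@alice:example.com', 'join'), ('@bob:example.com', 'leave')]
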
- txn.executemany( - """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership) - VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) - """, - [ - (room_id, key[0], key[1], ev_id, ev_id) - for key, ev_id in to_insert.items() - ], - ) - - # We now update `local_current_membership`. We do this regardless - # of whether we're still in the room or not to handle the case where - # e.g. we just got banned (where we need to record that fact here). - - # Note: Do we really want to delete rows here (that we do not - # subsequently reinsert below)? While technically correct it means - # we have no record of the fact the user *was* a member of the - # room but got, say, state reset out of it. - if to_delete or to_insert: - txn.executemany( - "DELETE FROM local_current_membership" - " WHERE room_id = ? AND user_id = ?", - ( - (room_id, state_key) - for etype, state_key in itertools.chain(to_delete, to_insert) - if etype == EventTypes.Member and self.is_mine_id(state_key) - ), - ) - - if to_insert: - txn.executemany( - """INSERT INTO local_current_membership - (room_id, user_id, event_id, membership) - VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) - """, - [ - (room_id, key[1], ev_id, ev_id) - for key, ev_id in to_insert.items() - if key[0] == EventTypes.Member and self.is_mine_id(key[1]) - ], - ) - - txn.call_after( - self.store._curr_state_delta_stream_cache.entity_has_changed, - room_id, - stream_id, - ) - - # Invalidate the various caches - - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invlidate the caches for all - # those users. - members_changed = { - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - } - - for member in members_changed: - txn.call_after( - self.store.get_rooms_for_user_with_stream_ordering.invalidate, - (member,), - ) - - self.store._invalidate_state_caches_and_stream( - txn, room_id, members_changed - ) - - def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): - """Update the room version in the database based off current state - events. - - This is used when we're about to delete current state and we want to - ensure that the `rooms.room_version` column is up to date. - """ - - sql = """ - SELECT json FROM event_json - INNER JOIN current_state_events USING (room_id, event_id) - WHERE room_id = ? AND type = ? AND state_key = ? 
- """ - txn.execute(sql, (room_id, EventTypes.Create, "")) - row = txn.fetchone() - if row: - event_json = db_to_json(row[0]) - content = event_json.get("content", {}) - creator = content.get("creator") - room_version_id = content.get("room_version", RoomVersions.V1.identifier) - - self.db.simple_upsert_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - values={"room_version": room_version_id}, - insertion_values={"is_public": False, "creator": creator}, - ) - - def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order - ): - for room_id, new_extrem in new_forward_extremities.items(): - self.db.simple_delete_txn( - txn, table="event_forward_extremities", keyvalues={"room_id": room_id} - ) - txn.call_after( - self.store.get_latest_event_ids_in_room.invalidate, (room_id,) - ) - - self.db.simple_insert_many_txn( - txn, - table="event_forward_extremities", - values=[ - {"event_id": ev_id, "room_id": room_id} - for room_id, new_extrem in new_forward_extremities.items() - for ev_id in new_extrem - ], - ) - # We now insert into stream_ordering_to_exterm a mapping from room_id, - # new stream_ordering to new forward extremeties in the room. - # This allows us to later efficiently look up the forward extremeties - # for a room before a given stream_ordering - self.db.simple_insert_many_txn( - txn, - table="stream_ordering_to_exterm", - values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_order, - } - for room_id, new_extrem in new_forward_extremities.items() - for event_id in new_extrem - ], - ) - - @classmethod - def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): - """Ensure that we don't have the same event twice. - - Pick the earliest non-outlier if there is one, else the earliest one. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): - Returns: - list[(EventBase, EventContext)]: filtered list - """ - new_events_and_contexts = OrderedDict() - for event, context in events_and_contexts: - prev_event_context = new_events_and_contexts.get(event.event_id) - if prev_event_context: - if not event.internal_metadata.is_outlier(): - if prev_event_context[0].internal_metadata.is_outlier(): - # To ensure correct ordering we pop, as OrderedDict is - # ordered by first insertion. 
- new_events_and_contexts.pop(event.event_id, None) - new_events_and_contexts[event.event_id] = (event, context) - else: - new_events_and_contexts[event.event_id] = (event, context) - return list(new_events_and_contexts.values()) - - def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): - """Update min_depth for each room - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - backfilled (bool): True if the events were backfilled - """ - depth_updates = {} - for event, context in events_and_contexts: - # Remove the any existing cache entries for the event_ids - txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - if not backfilled: - txn.call_after( - self.store._events_stream_cache.entity_has_changed, - event.room_id, - event.internal_metadata.stream_ordering, - ) - - if not event.internal_metadata.is_outlier() and not context.rejected: - depth_updates[event.room_id] = max( - event.depth, depth_updates.get(event.room_id, event.depth) - ) - - for room_id, depth in depth_updates.items(): - self._update_min_depth_for_room_txn(txn, room_id, depth) - - def _update_outliers_txn(self, txn, events_and_contexts): - """Update any outliers with new event info. - - This turns outliers into ex-outliers (unless the new event was - rejected). - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - - Returns: - list[(EventBase, EventContext)] new list, without events which - are already in the events table. - """ - txn.execute( - "SELECT event_id, outlier FROM events WHERE event_id in (%s)" - % (",".join(["?"] * len(events_and_contexts)),), - [event.event_id for event, _ in events_and_contexts], - ) - - have_persisted = {event_id: outlier for event_id, outlier in txn} - - to_remove = set() - for event, context in events_and_contexts: - if event.event_id not in have_persisted: - continue - - to_remove.add(event) - - if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. - continue - - outlier_persisted = have_persisted[event.event_id] - if not event.internal_metadata.is_outlier() and outlier_persisted: - # We received a copy of an event that we had already stored as - # an outlier in the database. We now have some state at that - # so we need to update the state_groups table with that state. - - # insert into event_to_state_groups. - try: - self._store_event_state_mappings_txn(txn, ((event, context),)) - except Exception: - logger.exception("") - raise - - metadata_json = encode_json(event.internal_metadata.get_dict()) - - sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" - txn.execute(sql, (metadata_json, event.event_id)) - - # Add an entry to the ex_outlier_stream table to replicate the - # change in outlier status to our workers. - stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group - self.db.simple_insert_txn( - txn, - table="ex_outlier_stream", - values={ - "event_stream_ordering": stream_order, - "event_id": event.event_id, - "state_group": state_group_id, - }, - ) - - sql = "UPDATE events SET outlier = ? WHERE event_id = ?" - txn.execute(sql, (False, event.event_id)) - - # Update the event_backward_extremities table now that this - # event isn't an outlier any more. 
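
# A pure-Python sketch of the duplicate-filtering rule in
# _filter_events_and_contexts_for_duplicates above: keep one entry per event_id,
# preferring a non-outlier copy over an earlier outlier copy. The Event namedtuple is a
# stand-in for illustration, not Synapse's EventBase.
from collections import OrderedDict, namedtuple

Event = namedtuple("Event", ["event_id", "outlier"])


def filter_duplicates(events):
    deduped = OrderedDict()
    for event in events:
        previous = deduped.get(event.event_id)
        if previous is not None:
            if not event.outlier and previous.outlier:
                # Pop first so the non-outlier copy takes this occurrence's position.
                deduped.pop(event.event_id)
                deduped[event.event_id] = event
        else:
            deduped[event.event_id] = event
    return list(deduped.values())


events = [Event("$a", outlier=True), Event("$b", outlier=False), Event("$a", outlier=False)]
print(filter_duplicates(events))  # the non-outlier copy of $a wins and sorts after $b
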
- self._update_backward_extremeties(txn, [event]) - - return [ec for ec in events_and_contexts if ec[0] not in to_remove] - - def _store_event_txn(self, txn, events_and_contexts): - """Insert new events into the event and event_json tables - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - """ - - if not events_and_contexts: - # nothing to do here - return - - def event_dict(event): - d = event.get_dict() - d.pop("redacted", None) - d.pop("redacted_because", None) - return d - - self.db.simple_insert_many_txn( - txn, - table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": encode_json( - event.internal_metadata.get_dict() - ), - "json": encode_json(event_dict(event)), - "format_version": event.format_version, - } - for event, _ in events_and_contexts - ], - ) - - self.db.simple_insert_many_txn( - txn, - table="events", - values=[ - { - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content and isinstance(event.content["url"], str) - ), - "count_as_unread": should_count_as_unread(event, context), - } - for event, context in events_and_contexts - ], - ) - - for event, _ in events_and_contexts: - if not event.internal_metadata.is_redacted(): - # If we're persisting an unredacted event we go and ensure - # that we mark any redactions that reference this event as - # requiring censoring. - self.db.simple_update_txn( - txn, - table="redactions", - keyvalues={"redacts": event.event_id}, - updatevalues={"have_censored": False}, - ) - - def _store_rejected_events_txn(self, txn, events_and_contexts): - """Add rows to the 'rejections' table for received events which were - rejected - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - - Returns: - list[(EventBase, EventContext)] new list, without the rejected - events. - """ - # Remove the rejected events from the list now that we've added them - # to the events table and the events_json table. - to_remove = set() - for event, context in events_and_contexts: - if context.rejected: - # Insert the event_id into the rejections table - self._store_rejections_txn(txn, event.event_id, context.rejected) - to_remove.add(event) - - return [ec for ec in events_and_contexts if ec[0] not in to_remove] - - def _update_metadata_tables_txn( - self, txn, events_and_contexts, all_events_and_contexts, backfilled - ): - """Update all the miscellaneous tables for new events - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - backfilled (bool): True if the events were backfilled - """ - - # Insert all the push actions into the event_push_actions table. 
- self._set_push_actions_for_event_and_users_txn( - txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - ) - - if not events_and_contexts: - # nothing to do here - return - - for event, context in events_and_contexts: - if event.type == EventTypes.Redaction and event.redacts is not None: - # Remove the entries in the event_push_actions table for the - # redacted event. - self._remove_push_actions_for_event_id_txn( - txn, event.room_id, event.redacts - ) - - # Remove from relations table. - self._handle_redaction(txn, event.redacts) - - # Update the event_forward_extremities, event_backward_extremities and - # event_edges tables. - self._handle_mult_prev_events( - txn, events=[event for event, _ in events_and_contexts] - ) - - for event, _ in events_and_contexts: - if event.type == EventTypes.Name: - # Insert into the event_search table. - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - # Insert into the event_search table. - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Message: - # Insert into the event_search table. - self._store_room_message_txn(txn, event) - elif event.type == EventTypes.Redaction and event.redacts is not None: - # Insert into the redactions table. - self._store_redaction(txn, event) - elif event.type == EventTypes.Retention: - # Update the room_retention table. - self._store_retention_policy_for_room_txn(txn, event) - - self._handle_event_relations(txn, event) - - # Store the labels for this event. - labels = event.content.get(EventContentFields.LABELS) - if labels: - self.insert_labels_for_event_txn( - txn, event.event_id, labels, event.room_id, event.depth - ) - - if self._ephemeral_messages_enabled: - # If there's an expiry timestamp on the event, store it. - expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if isinstance(expiry_ts, int) and not event.is_state(): - self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) - - # Insert into the room_memberships table. - self._store_room_members_txn( - txn, - [ - event - for event, _ in events_and_contexts - if event.type == EventTypes.Member - ], - backfilled=backfilled, - ) - - # Insert event_reference_hashes table. - self._store_event_reference_hashes_txn( - txn, [event for event, _ in events_and_contexts] - ) - - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, context in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? 
- if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) - - # Prefill the event cache - self._add_to_cache(txn, events_and_contexts) - - def _add_to_cache(self, txn, events_and_contexts): - to_prefill = [] - - rows = [] - N = 200 - for i in range(0, len(events_and_contexts), N): - ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} - if not ev_map: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM events as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE " - ) - - clause, args = make_in_list_sql_clause( - self.database_engine, "e.event_id", list(ev_map) - ) - - txn.execute(sql + clause, args) - rows = self.db.cursor_to_dict(txn) - for row in rows: - event = ev_map[row["event_id"]] - if not row["rejects"] and not row["redacts"]: - to_prefill.append( - _EventCacheEntry(event=event, redacted_event=None) - ) - - def prefill(): - for cache_entry in to_prefill: - self.store._get_event_cache.prefill( - (cache_entry[0].event_id,), cache_entry - ) - - txn.call_after(prefill) - - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event - txn.call_after(self.store._invalidate_get_event_cache, event.redacts) - - self.db.simple_insert_txn( - txn, - table="redactions", - values={ - "event_id": event.event_id, - "redacts": event.redacts, - "received_ts": self._clock.time_msec(), - }, - ) - - def insert_labels_for_event_txn( - self, txn, event_id, labels, room_id, topological_ordering - ): - """Store the mapping between an event's ID and its labels, with one row per - (event_id, label) tuple. - - Args: - txn (LoggingTransaction): The transaction to execute. - event_id (str): The event's ID. - labels (list[str]): A list of text labels. - room_id (str): The ID of the room the event was sent to. - topological_ordering (int): The position of the event in the room's topology. - """ - return self.db.simple_insert_many_txn( - txn=txn, - table="event_labels", - values=[ - { - "event_id": event_id, - "label": label, - "room_id": room_id, - "topological_ordering": topological_ordering, - } - for label in labels - ], - ) - - def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): - """Save the expiry timestamp associated with a given event ID. - - Args: - txn (LoggingTransaction): The database transaction to use. - event_id (str): The event ID the expiry timestamp is associated with. - expiry_ts (int): The timestamp at which to expire (delete) the event. - """ - return self.db.simple_insert_txn( - txn=txn, - table="event_expiry", - values={"event_id": event_id, "expiry_ts": expiry_ts}, - ) - - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. - """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append( - { - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": memoryview(ref_hash_bytes), - } - ) - - self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) - - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. 
- """ - self.db.simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self.store._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.store.get_invited_rooms_for_local_user.invalidate, - (event.state_key,), - ) - - # We update the local_current_membership table only if the event is - # "current", i.e., its something that has just happened. - # - # This will usually get updated by the `current_state_events` handling, - # unless its an outlier, and an outlier is only "current" if it's an "out of - # band membership", like a remote invite or a rejection of a remote invite. - if ( - self.is_mine_id(event.state_key) - and not backfilled - and event.internal_metadata.is_outlier() - and event.internal_metadata.is_out_of_band_membership() - ): - self.db.simple_upsert_txn( - txn, - table="local_current_membership", - keyvalues={"room_id": event.room_id, "user_id": event.state_key}, - values={ - "event_id": event.event_id, - "membership": event.membership, - }, - ) - - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events - - Args: - txn - event (EventBase) - """ - relation = event.content.get("m.relates_to") - if not relation: - # No relations - return - - rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - # Unknown relation type - return - - parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation - return - - aggregation_key = relation.get("key") - - self.db.simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, - }, - ) - - txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) - txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) - ) - - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. - - Args: - txn - redacted_event_id (str): The event that was redacted. 
- """ - - self.db.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) - - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: - self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"] - ) - - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: - self.store_event_search_txn( - txn, event, "content.name", event.content["name"] - ) - - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: - self.store_event_search_txn( - txn, event, "content.body", event.content["body"] - ) - - def _store_retention_policy_for_room_txn(self, txn, event): - if hasattr(event, "content") and ( - "min_lifetime" in event.content or "max_lifetime" in event.content - ): - if ( - "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), int) - ) or ( - "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), int) - ): - # Ignore the event if one of the value isn't an integer. - return - - self.db.simple_insert_txn( - txn=txn, - table="room_retention", - values={ - "room_id": event.room_id, - "event_id": event.event_id, - "min_lifetime": event.content.get("min_lifetime"), - "max_lifetime": event.content.get("max_lifetime"), - }, - ) - - self.store._invalidate_cache_and_stream( - txn, self.store.get_retention_policy_for_room, (event.room_id,) - ) - - def store_event_search_txn(self, txn, event, key, value): - """Add event to the search table - - Args: - txn (cursor): - event (EventBase): - key (str): - value (str): - """ - self.store.store_search_entries_txn( - txn, - ( - SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ), - ), - ) - - def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): - """Handles moving push actions from staging table to main - event_push_actions table for all events in `events_and_contexts`. - - Also ensures that all events in `all_events_and_contexts` are removed - from the push action staging area. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? - """ - - if events_and_contexts: - txn.executemany( - sql, - ( - ( - event.room_id, - event.internal_metadata.stream_ordering, - event.depth, - event.event_id, - ) - for event, _ in events_and_contexts - ), - ) - - for event, _ in events_and_contexts: - user_ids = self.db.simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcol="user_id", - ) - - for uid in user_ids: - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid), - ) - - # Now we delete the staging area for *all* events that were being - # persisted. 
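
# A sqlite3 sketch of the two-step move performed around this point: the
# INSERT ... SELECT above copies rows for each persisted event out of the staging table,
# and the DELETE just below clears the staging rows. Toy schema with a trimmed column set.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE event_push_actions_staging (event_id TEXT, user_id TEXT, actions TEXT);
    CREATE TABLE event_push_actions (
        room_id TEXT, event_id TEXT, user_id TEXT, actions TEXT, stream_ordering INTEGER
    );
    INSERT INTO event_push_actions_staging VALUES
        ('$ev1', '@alice:example.com', '["notify"]'),
        ('$ev1', '@bob:example.com', '["notify"]');
    """
)

persisted = [("!room:example.com", 42, "$ev1")]  # (room_id, stream_ordering, event_id)

conn.executemany(
    """
    INSERT INTO event_push_actions (room_id, event_id, user_id, actions, stream_ordering)
    SELECT ?, event_id, user_id, actions, ?
    FROM event_push_actions_staging WHERE event_id = ?
    """,
    persisted,
)
conn.executemany(
    "DELETE FROM event_push_actions_staging WHERE event_id = ?",
    [(event_id,) for _, _, event_id in persisted],
)

print(conn.execute("SELECT user_id FROM event_push_actions").fetchall())
print(conn.execute("SELECT count(*) FROM event_push_actions_staging").fetchone())  # (0,)
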
- txn.executemany( - "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ((event.event_id,) for event, _ in all_events_and_contexts), - ) - - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,), - ) - txn.execute( - "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id), - ) - - def _store_rejections_txn(self, txn, event_id, reason): - self.db.simple_insert_txn( - txn, - table="rejections", - values={ - "event_id": event_id, - "reason": reason, - "last_check": self._clock.time_msec(), - }, - ) - - def _store_event_state_mappings_txn( - self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] - ): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.state_group_before_event - continue - - state_groups[event.event_id] = context.state_group - - self.db.simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in state_groups.items() - ], - ) - - for event_id, state_group_id in state_groups.items(): - txn.call_after( - self.store._get_state_group_for_event.prefill, - (event_id,), - state_group_id, - ) - - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self.store._get_min_depth_interaction(txn, room_id) - - if min_depth is not None and depth >= min_depth: - return - - self.db.simple_upsert_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - values={"min_depth": depth}, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self.db.simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id in ev.prev_event_ids() - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany( - query, - [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events - for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ], - ) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" 
- ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) - for ev in events - if not ev.internal_metadata.is_outlier() - ], - ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py deleted file mode 100644 index 663c94b24f..0000000000 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ /dev/null @@ -1,585 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from twisted.internet import defer - -from synapse.api.constants import EventContentFields -from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database - -logger = logging.getLogger(__name__) - - -class EventsBackgroundUpdatesStore(SQLBaseStore): - - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - - def __init__(self, database: Database, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) - - self.db.updates.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts - ) - self.db.updates.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, - self._background_reindex_fields_sender, - ) - - self.db.updates.register_background_index_update( - "event_contains_url_index", - index_name="event_contains_url_index", - table="events", - columns=["room_id", "topological_ordering", "stream_ordering"], - where_clause="contains_url = true AND outlier = false", - ) - - # an event_id index on event_search is useful for the purge_history - # api. 
Plus it means we get to enforce some integrity with a UNIQUE - # clause - self.db.updates.register_background_index_update( - "event_search_event_id_idx", - index_name="event_search_event_id_idx", - table="event_search", - columns=["event_id"], - unique=True, - psql_only=True, - ) - - self.db.updates.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update - ) - - self.db.updates.register_background_update_handler( - "redactions_received_ts", self._redactions_received_ts - ) - - # This index gets deleted in `event_fix_redactions_bytes` update - self.db.updates.register_background_index_update( - "event_fix_redactions_bytes_create_index", - index_name="redactions_censored_redacts", - table="redactions", - columns=["redacts"], - where_clause="have_censored", - ) - - self.db.updates.register_background_update_handler( - "event_fix_redactions_bytes", self._event_fix_redactions_bytes - ) - - self.db.updates.register_background_update_handler( - "event_store_labels", self._event_store_labels - ) - - self.db.updates.register_background_index_update( - "redactions_have_censored_ts_idx", - index_name="redactions_have_censored_ts", - table="redactions", - columns=["received_ts"], - where_clause="NOT have_censored", - ) - - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, json FROM events" - " INNER JOIN event_json USING (event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - - update_rows = [] - for row in rows: - try: - event_id = row[1] - event_json = db_to_json(row[2]) - sender = event_json["sender"] - content = event_json["content"] - - contains_url = "url" in content - if contains_url: - contains_url &= isinstance(content["url"], str) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - update_rows.append((sender, contains_url, event_id)) - - sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" 
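
# A standalone sketch of the clumped-update pattern this background update applies just
# below: compute a list of (new values, key) rows, then push them with executemany in
# fixed-size clumps so no single statement grows unboundedly. sqlite3 with toy data.
import sqlite3

INSERT_CLUMP_SIZE = 1000

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY, sender TEXT, contains_url BOOLEAN)")
conn.executemany(
    "INSERT INTO events (event_id) VALUES (?)", [("$%d" % i,) for i in range(2500)]
)

# Pretend these were re-derived from the event JSON, as the reindex above does.
update_rows = [("@user:example.com", False, "$%d" % i) for i in range(2500)]

sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
    clump = update_rows[index : index + INSERT_CLUMP_SIZE]
    conn.executemany(sql, clump)

print(conn.execute("SELECT count(*) FROM events WHERE sender IS NOT NULL").fetchone())  # (2500,)
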
- - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - } - - self.db.updates._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.db.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn - ) - - if not result: - yield self.db.updates._end_background_update( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME - ) - - return result - - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id FROM events" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - rows_to_update = [] - - chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] - for chunk in chunks: - ev_rows = self.db.simple_select_many_txn( - txn, - table="event_json", - column="event_id", - iterable=chunk, - retcols=["event_id", "json"], - keyvalues={}, - ) - - for row in ev_rows: - event_id = row["event_id"] - event_json = db_to_json(row["json"]) - try: - origin_server_ts = event_json["origin_server_ts"] - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - rows_to_update.append((origin_server_ts, event_id)) - - sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update), - } - - self.db.updates._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress - ) - - return len(rows_to_update) - - result = yield self.db.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn - ) - - if not result: - yield self.db.updates._end_background_update( - self.EVENT_ORIGIN_SERVER_TS_NAME - ) - - return result - - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): - """Background update to clean out extremities that should have been - deleted previously. - - Mainly used to deal with the aftermath of #5269. - """ - - # This works by first copying all existing forward extremities into the - # `_extremities_to_check` table at start up, and then checking each - # event in that table whether we have any descendants that are not - # soft-failed/rejected. If that is the case then we delete that event - # from the forward extremities table. - # - # For efficiency, we do this in batches by recursively pulling out all - # descendants of a batch until we find the non soft-failed/rejected - # events, i.e. 
the set of descendants whose chain of prev events back - # to the batch of extremities are all soft-failed or rejected. - # Typically, we won't find any such events as extremities will rarely - # have any descendants, but if they do then we should delete those - # extremities. - - def _cleanup_extremities_bg_update_txn(txn): - # The set of extremity event IDs that we're checking this round - original_set = set() - - # A dict[str, set[str]] of event ID to their prev events. - graph = {} - - # The set of descendants of the original set that are not rejected - # nor soft-failed. Ancestors of these events should be removed - # from the forward extremities table. - non_rejected_leaves = set() - - # Set of event IDs that have been soft failed, and for which we - # should check if they have descendants which haven't been soft - # failed. - soft_failed_events_to_lookup = set() - - # First, we get `batch_size` events from the table, pulling out - # their successor events, if any, and the successor events' - # rejection status. - txn.execute( - """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL, events.outlier - FROM ( - SELECT event_id AS prev_event_id - FROM _extremities_to_check - LIMIT ? - ) AS f - LEFT JOIN event_edges USING (prev_event_id) - LEFT JOIN events USING (event_id) - LEFT JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - """, - (batch_size,), - ) - - for prev_event_id, event_id, metadata, rejected, outlier in txn: - original_set.add(prev_event_id) - - if not event_id or outlier: - # Common case where the forward extremity doesn't have any - # descendants. - continue - - graph.setdefault(event_id, set()).add(prev_event_id) - - soft_failed = False - if metadata: - soft_failed = db_to_json(metadata).get("soft_failed") - - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # Now we recursively check all the soft-failed descendants we - # found above in the same way, until we have nothing left to - # check. - while soft_failed_events_to_lookup: - # We only want to do 100 at a time, so we split given list - # into two. - batch = list(soft_failed_events_to_lookup) - to_check, to_defer = batch[:100], batch[100:] - soft_failed_events_to_lookup = set(to_defer) - - sql = """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - INNER JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - WHERE - NOT events.outlier - AND - """ - clause, args = make_in_list_sql_clause( - self.database_engine, "prev_event_id", to_check - ) - txn.execute(sql + clause, list(args)) - - for prev_event_id, event_id, metadata, rejected in txn: - if event_id in graph: - # Already handled this event previously, but we still - # want to record the edge. - graph[event_id].add(prev_event_id) - continue - - graph[event_id] = {prev_event_id} - - soft_failed = db_to_json(metadata).get("soft_failed") - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # We have a set of non-soft-failed descendants, so we recurse up - # the graph to find all ancestors and add them to the set of event - # IDs that we can delete from forward extremities table. 
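
# A pure-Python sketch of the walk performed next: starting from the descendants that are
# neither rejected nor soft-failed, follow prev-event edges back up and collect every
# ancestor; any original extremity reached this way has a "live" descendant and can be
# dropped from event_forward_extremities. The data below is a toy example.
def extremities_to_delete(original_set, graph, non_rejected_leaves):
    """graph maps event_id -> set of prev_event_ids."""
    to_delete = set()
    frontier = set(non_rejected_leaves)
    while frontier:
        event_id = frontier.pop()
        prev_event_ids = graph.get(event_id, set())
        frontier.update(prev_event_ids)
        to_delete.update(prev_event_ids)
    return to_delete & original_set


graph = {"$soft_failed": {"$old_extremity"}, "$live": {"$soft_failed"}}
print(extremities_to_delete({"$old_extremity"}, graph, {"$live"}))  # {'$old_extremity'}
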
- to_delete = set() - while non_rejected_leaves: - event_id = non_rejected_leaves.pop() - prev_event_ids = graph.get(event_id, set()) - non_rejected_leaves.update(prev_event_ids) - to_delete.update(prev_event_ids) - - to_delete.intersection_update(original_set) - - deleted = self.db.simple_delete_many_txn( - txn=txn, - table="event_forward_extremities", - column="event_id", - iterable=to_delete, - keyvalues={}, - ) - - logger.info( - "Deleted %d forward extremities of %d checked, to clean up #5269", - deleted, - len(original_set), - ) - - if deleted: - # We now need to invalidate the caches of these rooms - rows = self.db.simple_select_many_txn( - txn, - table="events", - column="event_id", - iterable=to_delete, - keyvalues={}, - retcols=("room_id",), - ) - room_ids = {row["room_id"] for row in rows} - for room_id in room_ids: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) - ) - - self.db.simple_delete_many_txn( - txn=txn, - table="_extremities_to_check", - column="event_id", - iterable=original_set, - keyvalues={}, - ) - - return len(original_set) - - num_handled = yield self.db.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn - ) - - if not num_handled: - yield self.db.updates._end_background_update( - self.DELETE_SOFT_FAILED_EXTREMITIES - ) - - def _drop_table_txn(txn): - txn.execute("DROP TABLE _extremities_to_check") - - yield self.db.runInteraction( - "_cleanup_extremities_bg_update_drop_table", _drop_table_txn - ) - - return num_handled - - @defer.inlineCallbacks - def _redactions_received_ts(self, progress, batch_size): - """Handles filling out the `received_ts` column in redactions. - """ - last_event_id = progress.get("last_event_id", "") - - def _redactions_received_ts_txn(txn): - # Fetch the set of event IDs that we want to update - sql = """ - SELECT event_id FROM redactions - WHERE event_id > ? - ORDER BY event_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_event_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - (upper_event_id,) = rows[-1] - - # Update the redactions with the received_ts. - # - # Note: Not all events have an associated received_ts, so we - # fallback to using origin_server_ts. If we for some reason don't - # have an origin_server_ts, lets just use the current timestamp. - # - # We don't want to leave it null, as then we'll never try and - # censor those redactions. - sql = """ - UPDATE redactions - SET received_ts = ( - SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events - WHERE events.event_id = redactions.event_id - ) - WHERE ? <= event_id AND event_id <= ? - """ - - txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - - self.db.updates._background_update_progress_txn( - txn, "redactions_received_ts", {"last_event_id": upper_event_id} - ) - - return len(rows) - - count = yield self.db.runInteraction( - "_redactions_received_ts", _redactions_received_ts_txn - ) - - if not count: - yield self.db.updates._end_background_update("redactions_received_ts") - - return count - - @defer.inlineCallbacks - def _event_fix_redactions_bytes(self, progress, batch_size): - """Undoes hex encoded censored redacted event JSON. - """ - - def _event_fix_redactions_bytes_txn(txn): - # This update is quite fast due to new index. 
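
# Roughly what the UPDATE just below does, sketched in Python: some censored event JSON
# ended up stored as the textual form of a bytea (a hex string such as "\x7b22...")
# rather than as JSON text, so anything not starting with "{" is decoded back to UTF-8.
# This is only an illustrative analogue of the Postgres convert_from(json::bytea, 'utf8')
# expression, not a drop-in replacement for it.
def fix_hex_encoded_json(stored):
    if stored.startswith("{"):
        return stored  # already plain JSON text
    assert stored.startswith("\\x"), "unexpected encoding"
    return bytes.fromhex(stored[2:]).decode("utf-8")


print(fix_hex_encoded_json("\\x7b2274797065223a20226d2e726f6f6d2e6d657373616765227d"))
# {"type": "m.room.message"}
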
- txn.execute( - """ - UPDATE event_json - SET - json = convert_from(json::bytea, 'utf8') - FROM redactions - WHERE - redactions.have_censored - AND event_json.event_id = redactions.redacts - AND json NOT LIKE '{%'; - """ - ) - - txn.execute("DROP INDEX redactions_censored_redacts") - - yield self.db.runInteraction( - "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn - ) - - yield self.db.updates._end_background_update("event_fix_redactions_bytes") - - return 1 - - @defer.inlineCallbacks - def _event_store_labels(self, progress, batch_size): - """Background update handler which will store labels for existing events.""" - last_event_id = progress.get("last_event_id", "") - - def _event_store_labels_txn(txn): - txn.execute( - """ - SELECT event_id, json FROM event_json - LEFT JOIN event_labels USING (event_id) - WHERE event_id > ? AND label IS NULL - ORDER BY event_id LIMIT ? - """, - (last_event_id, batch_size), - ) - - results = list(txn) - - nbrows = 0 - last_row_event_id = "" - for (event_id, event_json_raw) in results: - try: - event_json = db_to_json(event_json_raw) - - self.db.simple_insert_many_txn( - txn=txn, - table="event_labels", - values=[ - { - "event_id": event_id, - "label": label, - "room_id": event_json["room_id"], - "topological_ordering": event_json["depth"], - } - for label in event_json["content"].get( - EventContentFields.LABELS, [] - ) - if isinstance(label, str) - ], - ) - except Exception as e: - logger.warning( - "Unable to load event %s (no labels will be imported): %s", - event_id, - e, - ) - - nbrows += 1 - last_row_event_id = event_id - - self.db.updates._background_update_progress_txn( - txn, "event_store_labels", {"last_event_id": last_row_event_id} - ) - - return nbrows - - num_rows = yield self.db.runInteraction( - desc="event_store_labels", func=_event_store_labels_txn - ) - - if not num_rows: - yield self.db.updates._end_background_update("event_store_labels") - - return num_rows diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py deleted file mode 100644 index b03b259636..0000000000 --- a/synapse/storage/data_stores/main/events_worker.py +++ /dev/null @@ -1,1454 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
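
# A minimal sketch of the stream-ID allocation pattern this worker store sets up just
# below (StreamIdGenerator, and the get_next_mult usage earlier in this patch): a single
# writer hands out blocks of monotonically increasing stream orderings via a context
# manager. This toy version is purely in-memory and skips the unfinished-ID bookkeeping
# the real generator has to do.
import threading
from contextlib import contextmanager


class ToyStreamIdGenerator:
    def __init__(self, current=0, step=1):
        self._lock = threading.Lock()
        self._current = current
        self._step = step  # step=-1 gives the negative orderings used for backfill

    @contextmanager
    def get_next_mult(self, n):
        with self._lock:
            ids = [self._current + self._step * (i + 1) for i in range(n)]
            self._current += self._step * n
        yield ids


gen = ToyStreamIdGenerator(current=100)
with gen.get_next_mult(3) as stream_orderings:
    print(stream_orderings)  # [101, 102, 103]

backfill_gen = ToyStreamIdGenerator(current=0, step=-1)
with backfill_gen.get_next_mult(2) as stream_orderings:
    print(stream_orderings)  # [-1, -2]
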
- -from __future__ import division - -import itertools -import logging -import threading -from collections import namedtuple -from typing import List, Optional, Tuple - -from constantly import NamedConstant, Names - -from twisted.internet import defer - -from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError, SynapseError -from synapse.api.room_versions import ( - KNOWN_ROOM_VERSIONS, - EventFormatVersions, - RoomVersions, -) -from synapse.events import make_event_from_dict -from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import BackfillStream -from synapse.replication.tcp.streams.events import EventsStream -from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database -from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) -from synapse.util.iterutils import batch_iter -from synapse.util.metrics import Measure - -logger = logging.getLogger(__name__) - - -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events - - -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - -class EventRedactBehaviour(Names): - """ - What to do when retrieving a redacted event from the database. - """ - - AS_IS = NamedConstant() - REDACT = NamedConstant() - BLOCK = NamedConstant() - - -class EventsWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(EventsWorkerStore, self).__init__(database, db_conn, hs) - - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. 
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) - - self._get_event_cache = Cache( - "*getEvent*", - keylen=3, - max_entries=hs.config.caches.event_cache_size, - apply_cache_factor_from_config=False, - ) - - self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] - self._event_fetch_ongoing = 0 - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(-token) - - super().process_replication_rows(stream_name, instance_name, token, rows) - - def get_received_ts(self, event_id): - """Get received_ts (when it was persisted) for the event. - - Raises an exception for unknown events. - - Args: - event_id (str) - - Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. - """ - return self.db.simple_select_one_onecol( - table="events", - keyvalues={"event_id": event_id}, - retcol="received_ts", - desc="get_received_ts", - ) - - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. - - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts - - return self.db.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) - - @defer.inlineCallbacks - def get_event( - self, - event_id: str, - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - allow_none: bool = False, - check_room_id: Optional[str] = None, - ): - """Get an event from the database by event_id. - - Args: - event_id: The event_id of the event to fetch - - redact_behaviour: Determine what to do with a redacted event. Possible values: - * AS_IS - Return the full event body with no redacted content - * REDACT - Return the event but with a redacted body - * DISALLOW - Do not return redacted events (behave as per allow_none - if the event is redacted) - - get_prev_content: If True and event is a state event, - include the previous states content in the unsigned field. - - allow_rejected: If True, return rejected events. Otherwise, - behave as per allow_none. - - allow_none: If True, return None if no event found, if - False throw a NotFoundError - - check_room_id: if not None, check the room of the found event. - If there is a mismatch, behave as per allow_none. 
- - Returns: - Deferred[EventBase|None] - """ - if not isinstance(event_id, str): - raise TypeError("Invalid event event_id %r" % (event_id,)) - - events = yield self.get_events_as_list( - [event_id], - redact_behaviour=redact_behaviour, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - event = events[0] if events else None - - if event is not None and check_room_id is not None: - if event.room_id != check_room_id: - event = None - - if event is None and not allow_none: - raise NotFoundError("Could not find event %s" % (event_id,)) - - return event - - @defer.inlineCallbacks - def get_events( - self, - event_ids: List[str], - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - ): - """Get events from the database - - Args: - event_ids: The event_ids of the events to fetch - - redact_behaviour: Determine what to do with a redacted event. Possible - values: - * AS_IS - Return the full event body with no redacted content - * REDACT - Return the event but with a redacted body - * DISALLOW - Do not return redacted events (omit them from the response) - - get_prev_content: If True and event is a state event, - include the previous states content in the unsigned field. - - allow_rejected: If True, return rejected events. Otherwise, - omits rejeted events from the response. - - Returns: - Deferred : Dict from event_id to event. - """ - events = yield self.get_events_as_list( - event_ids, - redact_behaviour=redact_behaviour, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - return {e.event_id: e for e in events} - - @defer.inlineCallbacks - def get_events_as_list( - self, - event_ids: List[str], - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - ): - """Get events from the database and return in a list in the same order - as given by `event_ids` arg. - - Unknown events will be omitted from the response. - - Args: - event_ids: The event_ids of the events to fetch - - redact_behaviour: Determine what to do with a redacted event. Possible values: - * AS_IS - Return the full event body with no redacted content - * REDACT - Return the event but with a redacted body - * DISALLOW - Do not return redacted events (omit them from the response) - - get_prev_content: If True and event is a state event, - include the previous states content in the unsigned field. - - allow_rejected: If True, return rejected events. Otherwise, - omits rejected events from the response. - - Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. - - Note that the returned list may be smaller than the list of event - IDs if not all events could be fetched. - """ - - if not event_ids: - return [] - - # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( - set(event_ids), allow_rejected=allow_rejected - ) - - events = [] - for event_id in event_ids: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if not allow_rejected: - assert not entry.event.rejected_reason, ( - "rejected event returned from _get_events_from_cache_or_db despite " - "allow_rejected=False" - ) - - # We may not have had the original event when we received a redaction, so - # we have to recheck auth now. 
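
# A small sketch of the cross-check applied just below when a redaction arrived before
# the event it redacts: once the original turns up, this layer only lets the redaction
# through if both events were sent from the same server (other cases are re-checked
# elsewhere). get_domain_from_id is Synapse's helper; the one-liner here just mirrors its
# behaviour for well-formed user IDs.
def get_domain_from_id(user_id):
    return user_id.split(":", 1)[1]


def redaction_sender_allowed(redaction_sender, original_sender):
    return get_domain_from_id(redaction_sender) == get_domain_from_id(original_sender)


print(redaction_sender_allowed("@mod:example.com", "@alice:example.com"))   # True
print(redaction_sender_allowed("@mod:evil.example", "@alice:example.com"))  # False
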
- - if not allow_rejected and entry.event.type == EventTypes.Redaction: - if entry.event.redacts is None: - # A redacted redaction doesn't have a `redacts` key, in - # which case lets just withhold the event. - # - # Note: Most of the time if the redactions has been - # redacted we still have the un-redacted event in the DB - # and so we'll still see the `redacts` key. However, this - # isn't always true e.g. if we have censored the event. - logger.debug( - "Withholding redaction event %s as we don't have redacts key", - event_id, - ) - continue - - redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) - original_event_entry = event_map.get(redacted_event_id) - if not original_event_entry: - # we don't have the redacted event (or it was rejected). - # - # We assume that the redaction isn't authorized for now; if the - # redacted event later turns up, the redaction will be re-checked, - # and if it is found valid, the original will get redacted before it - # is served to the client. - logger.debug( - "Withholding redaction event %s since we don't (yet) have the " - "original %s", - event_id, - redacted_event_id, - ) - continue - - original_event = original_event_entry.event - if original_event.type == EventTypes.Create: - # we never serve redactions of Creates to clients. - logger.info( - "Withholding redaction %s of create event %s", - event_id, - redacted_event_id, - ) - continue - - if original_event.room_id != entry.event.room_id: - logger.info( - "Withholding redaction %s of event %s from a different room", - event_id, - redacted_event_id, - ) - continue - - if entry.event.internal_metadata.need_to_check_redaction(): - original_domain = get_domain_from_id(original_event.sender) - redaction_domain = get_domain_from_id(entry.event.sender) - if original_domain != redaction_domain: - # the senders don't match, so this is forbidden - logger.info( - "Withholding redaction %s whose sender domain %s doesn't " - "match that of redacted event %s %s", - event_id, - redaction_domain, - redacted_event_id, - original_domain, - ) - continue - - # Update the cache to save doing the checks again. - entry.event.internal_metadata.recheck_redaction = False - - event = entry.event - - if entry.redacted_event: - if redact_behaviour == EventRedactBehaviour.BLOCK: - # Skip this event - continue - elif redact_behaviour == EventRedactBehaviour.REDACT: - event = entry.redacted_event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - return events - - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): - """Fetch a bunch of events from the cache or the database. - - If events are pulled from the database, they will be cached for future lookups. - - Unknown events are omitted from the response. - - Args: - - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. 
- - Returns: - Deferred[Dict[str, _EventCacheEntry]]: - map from event id to result - """ - event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - log_ctx = current_context() - log_ctx.record_event_fetch(len(missing_events_ids)) - - # Note that _get_events_from_db is also responsible for turning db rows - # into FrozenEvents (via _get_event_from_row), which involves seeing if - # the events have been redacted, and if so pulling the redaction event out - # of the database to check it. - # - missing_events = yield self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) - - event_entry_map.update(missing_events) - - return event_entry_map - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches - - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics - - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` - """ - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get( - (event_id,), None, update_metrics=update_metrics - ) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the _event_fetch_list queue. - """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - self._fetch_event_list(conn, event_list) - - def _fetch_event_list(self, conn, event_list): - """Handle a load of requests from the _event_fetch_list queue - - Args: - conn (twisted.enterprise.adbapi.Connection): database connection - - event_list (list[Tuple[list[str], Deferred]]): - The fetch requests. Each entry consists of a list of event - ids to be fetched, and a deferred to be completed once the - events have been fetched. - - The deferreds are callbacked with a dictionary mapping from event id - to event row. Note that it may well contain additional events that - were not part of this request. 
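The _do_fetch and _fetch_event_list pair described above amounts to a batching queue: callers append (event_ids, deferred) pairs to _event_fetch_list and a background worker drains the whole list in one database transaction, so concurrent lookups share a single query. A Twisted-free sketch of the same pattern, where load_rows is a hypothetical stand-in for the real row-fetching transaction:

# Sketch of a batch-fetch queue, not Synapse's implementation.
import threading

QUEUE_TIMEOUT_S = 0.1

class BatchFetcher:
    def __init__(self, load_rows):
        self._load_rows = load_rows
        self._lock = threading.Condition()
        self._pending = []  # list of (event_ids, result_dict, done_event)

    def fetch(self, event_ids):
        done = threading.Event()
        result = {}
        with self._lock:
            self._pending.append((event_ids, result, done))
            self._lock.notify()
        done.wait()
        return result

    def run(self):
        """Worker loop: drain all queued requests and serve them from one load."""
        while True:
            with self._lock:
                while not self._pending:
                    self._lock.wait(QUEUE_TIMEOUT_S)
                batch, self._pending = self._pending, []
            wanted = {eid for ids, _, _ in batch for eid in ids}
            rows = self._load_rows(wanted)  # one "transaction" for the whole batch
            for ids, result, done in batch:
                result.update({eid: rows[eid] for eid in ids if eid in rows})
                done.set()

fetcher = BatchFetcher(lambda ids: {eid: {"event_id": eid} for eid in ids})
threading.Thread(target=fetcher.run, daemon=True).start()
print(fetcher.fetch(["$a", "$b"]))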
- """ - with Measure(self._clock, "_fetch_event_list"): - try: - events_to_fetch = { - event_id for events, _ in event_list for event_id in events - } - - row_dict = self.db.new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch - ) - - # We only want to resolve deferreds from the main thread - def fire(): - for _, d in event_list: - d.callback(row_dict) - - with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs, exc): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(exc) - - with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, e) - - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): - """Fetch a bunch of events from the database. - - Returned events will be added to the cache for future lookups. - - Unknown events are omitted from the response. - - Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. - - Returns: - Deferred[Dict[str, _EventCacheEntry]]: - map from event id to result. May return extra events which - weren't asked for. - """ - fetched_events = {} - events_to_fetch = event_ids - - while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) - - # we need to recursively fetch any redactions of those events - redaction_ids = set() - for event_id in events_to_fetch: - row = row_map.get(event_id) - fetched_events[event_id] = row - if row: - redaction_ids.update(row["redactions"]) - - events_to_fetch = redaction_ids.difference(fetched_events.keys()) - if events_to_fetch: - logger.debug("Also fetching redaction events %s", events_to_fetch) - - # build a map from event_id to EventBase - event_map = {} - for event_id, row in fetched_events.items(): - if not row: - continue - assert row["event_id"] == event_id - - rejected_reason = row["rejected_reason"] - - if not allow_rejected and rejected_reason: - continue - - d = db_to_json(row["json"]) - internal_metadata = db_to_json(row["internal_metadata"]) - - format_version = row["format_version"] - if format_version is None: - # This means that we stored the event before we had the concept - # of a event format version, so it must be a V1 event. 
- format_version = EventFormatVersions.V1 - - room_version_id = row["room_version_id"] - - if not room_version_id: - # this should only happen for out-of-band membership events - if not internal_metadata.get("out_of_band_membership"): - logger.warning( - "Room %s for event %s is unknown", d["room_id"], event_id - ) - continue - - # take a wild stab at the room version based on the event format - if format_version == EventFormatVersions.V1: - room_version = RoomVersions.V1 - elif format_version == EventFormatVersions.V2: - room_version = RoomVersions.V3 - else: - room_version = RoomVersions.V5 - else: - room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version: - logger.warning( - "Event %s in room %s has unknown room version %s", - event_id, - d["room_id"], - room_version_id, - ) - continue - - if room_version.event_format != format_version: - logger.error( - "Event %s in room %s with version %s has wrong format: " - "expected %s, was %s", - event_id, - d["room_id"], - room_version_id, - room_version.event_format, - format_version, - ) - continue - - original_ev = make_event_from_dict( - event_dict=d, - room_version=room_version, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - event_map[event_id] = original_ev - - # finally, we can decide whether each one needs redacting, and build - # the cache entries. - result_map = {} - for event_id, original_ev in event_map.items(): - redactions = fetched_events[event_id]["redactions"] - redacted_event = self._maybe_redact_event_row( - original_ev, redactions, event_map - ) - - cache_entry = _EventCacheEntry( - event=original_ev, redacted_event=redacted_event - ) - - self._get_event_cache.prefill((event_id,), cache_entry) - result_map[event_id] = cache_entry - - return result_map - - @defer.inlineCallbacks - def _enqueue_events(self, events): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. - - Args: - events (Iterable[str]): events to be fetched. - - Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. - May contain events that weren't requested. - """ - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append((events, events_d)) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process( - "fetch_events", self.db.runWithConnection, self._do_fetch - ) - - logger.debug("Loading %d events: %s", len(events), events) - with PreserveLoggingContext(): - row_map = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) - - return row_map - - def _fetch_event_rows(self, txn, event_ids): - """Fetch event rows from the database - - Events which are not found are omitted from the result. - - The returned per-event dicts contain the following keys: - - * event_id (str) - - * json (str): json-encoded event structure - - * internal_metadata (str): json-encoded internal metadata dict - - * format_version (int|None): The format of the event. Hopefully one - of EventFormatVersions. 'None' means the event predates - EventFormatVersions (so the event is format V1). - - * room_version_id (str|None): The version of the room which contains the event. - Hopefully one of RoomVersions. 
- - Due to historical reasons, there may be a few events in the database which - do not have an associated room; in this case None will be returned here. - - * rejected_reason (str|None): if the event was rejected, the reason - why. - - * redactions (List[str]): a list of event-ids which (claim to) redact - this event. - - Args: - txn (twisted.enterprise.adbapi.Connection): - event_ids (Iterable[str]): event IDs to fetch - - Returns: - Dict[str, Dict]: a map from event id to event info. - """ - event_dict = {} - for evs in batch_iter(event_ids, 200): - sql = """\ - SELECT - e.event_id, - e.internal_metadata, - e.json, - e.format_version, - r.room_version, - rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) - LEFT JOIN rejections as rej USING (event_id) - WHERE """ - - clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", evs - ) - - txn.execute(sql + clause, args) - - for row in txn: - event_id = row[0] - event_dict[event_id] = { - "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], - "redactions": [], - } - - # check for redactions - redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " - - clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) - - txn.execute(redactions_sql + clause, args) - - for (redacter, redacted) in txn: - d = event_dict.get(redacted) - if d: - d["redactions"].append(redacter) - - return event_dict - - def _maybe_redact_event_row(self, original_ev, redactions, event_map): - """Given an event object and a list of possible redacting event ids, - determine whether to honour any of those redactions and if so return a redacted - event. - - Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. - - Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. - """ - if original_ev.type == "m.room.create": - # we choose to ignore redactions of m.room.create events. - return None - - for redaction_id in redactions: - redaction_event = event_map.get(redaction_id) - if not redaction_event or redaction_event.rejected_reason: - # we don't have the redaction event, or the redaction event was not - # authorized. - logger.debug( - "%s was redacted by %s but redaction not found/authed", - original_ev.event_id, - redaction_id, - ) - continue - - if redaction_event.room_id != original_ev.room_id: - logger.debug( - "%s was redacted by %s but redaction was in a different room!", - original_ev.event_id, - redaction_id, - ) - continue - - # Starting in room version v3, some redactions need to be - # rechecked if we didn't have the redacted event at the - # time, so we recheck on read instead. - if redaction_event.internal_metadata.need_to_check_redaction(): - expected_domain = get_domain_from_id(original_ev.sender) - if get_domain_from_id(redaction_event.sender) == expected_domain: - # This redaction event is allowed. Mark as not needing a recheck. 
- redaction_event.internal_metadata.recheck_redaction = False - else: - # Senders don't match, so the event isn't actually redacted - logger.debug( - "%s was redacted by %s but the senders don't match", - original_ev.event_id, - redaction_id, - ) - continue - - logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id) - - # we found a good redaction event. Redact! - redacted_event = prune_event(original_ev) - redacted_event.unsigned["redacted_by"] = redaction_id - - # It's fine to add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = redaction_event - - return redacted_event - - # no valid redaction found for this event - return None - - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): - """Given a list of event ids, check if we have already processed and - stored them as non outliers. - """ - rows = yield self.db.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) - - return {r["event_id"] for r in rows} - - @defer.inlineCallbacks - def have_seen_events(self, event_ids): - """Given a list of event ids, check if we have already processed them. - - Args: - event_ids (iterable[str]): - - Returns: - Deferred[set[str]]: The events we have already seen. - """ - results = set() - - def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE " - clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", chunk - ) - txn.execute(sql + clause, args) - for (event_id,) in txn: - results.add(event_id) - - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db.runInteraction( - "have_seen_events", have_seen_events_txn, chunk - ) - return results - - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - - def _get_current_state_event_counts_txn(self, txn, room_id): - """ - See get_current_state_event_counts. - """ - sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_current_state_event_counts(self, room_id): - """ - Gets the current number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db.runInteraction( - "get_current_state_event_counts", - self._get_current_state_event_counts_txn, - room_id, - ) - - @defer.inlineCallbacks - def get_room_complexity(self, room_id): - """ - Get a rough approximation of the complexity of the room. This is used by - remote servers to decide whether they wish to join the room or not. - Higher complexity value indicates that being in the room will consume - more resources. 
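have_seen_events above shows a pattern used throughout this store: a potentially large set of IDs is split into fixed-size chunks (here 100) and each chunk becomes an IN (...) clause, keeping any single query bounded. A self-contained sqlite sketch of the same idea, with table and column names chosen for the example:

# Chunked IN-clause lookups; a standalone illustration, not Synapse code.
import itertools
import sqlite3

def have_seen(conn, ids, chunk_size=100):
    seen = set()
    it = iter(ids)
    for chunk in iter(lambda: list(itertools.islice(it, chunk_size)), []):
        placeholders = ", ".join("?" for _ in chunk)
        rows = conn.execute(
            f"SELECT event_id FROM events WHERE event_id IN ({placeholders})", chunk
        )
        seen.update(row[0] for row in rows)
    return seen

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY)")
conn.executemany("INSERT INTO events VALUES (?)", [("$a",), ("$b",)])
print(have_seen(conn, ["$a", "$b", "$missing"]))  # {'$a', '$b'}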
- - Args: - room_id (str) - - Returns: - Deferred[dict[str:int]] of complexity version to complexity. - """ - state_events = yield self.get_current_state_event_counts(room_id) - - # Call this one "v1", so we can introduce new ones as we want to develop - # it. - complexity_v1 = round(state_events / 500, 2) - - return {"v1": complexity_v1} - - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - """Returns new events, for the Events replication stream - - Args: - last_id: the last stream_id from the previous batch. - current_id: the maximum stream_id to return up to - limit: the maximum number of rows to return - - Returns: Deferred[List[Tuple]] - a list of events stream rows. Each tuple consists of a stream id as - the first element, followed by fields suitable for casting into an - EventsStreamRow. - """ - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_ex_outlier_stream_rows(self, last_id, current_id): - """Returns de-outliered events, for the Events replication stream - - Args: - last_id: the last stream_id from the previous batch. - current_id: the maximum stream_id to return up to - - Returns: Deferred[List[Tuple]] - a list of events stream rows. Each tuple consists of a stream id as - the first element, followed by fields suitable for casting into an - EventsStreamRow. - """ - - def get_ex_outlier_stream_rows_txn(txn): - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering ASC" - ) - - txn.execute(sql, (last_id, current_id)) - return txn.fetchall() - - return self.db.runInteraction( - "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn - ) - - async def get_all_new_backfill_event_rows( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: - """Get updates for backfill replication stream, including all new - backfilled events and events that have gone from being outliers to not. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. 
- - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - if last_id == current_id: - return [], current_id, False - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = [(row[0], row[1:]) for row in txn] - - limited = False - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - limited = True - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend((row[0], row[1:]) for row in txn) - - if len(new_event_updates) >= limit: - upper_bound = new_event_updates[-1][0] - limited = True - - return new_event_updates, upper_bound, limited - - return await self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - - async def get_all_updated_current_state_deltas( - self, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple], int, bool]: - """Fetch updates from current_state_delta_stream - - Args: - from_token: The previous stream token. Updates from this stream id will - be excluded. - - to_token: The current stream token (ie the upper limit). Updates up to this - stream id will be included (modulo the 'limit' param) - - target_row_count: The number of rows to try to return. If more rows are - available, we will set 'limited' in the result. In the event of a large - batch, we may return more rows than this. - Returns: - A triplet `(updates, new_last_token, limited)`, where: - * `updates` is a list of database tuples. - * `new_last_token` is the new position in stream. - * `limited` is whether there are more updates to fetch. - """ - - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (from_token, to_token, target_row_count)) - return txn.fetchall() - - def get_deltas_for_stream_id_txn(txn, stream_id): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE stream_id = ? - """ - txn.execute(sql, [stream_id]) - return txn.fetchall() - - # we need to make sure that, for every stream id in the results, we get *all* - # the rows with that stream id. 
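The comment above states the invariant the pagination below must preserve: a batch may be cut short by the row limit, but never in the middle of a stream ID, otherwise a consumer could see a partial set of deltas for that position. A small pure-Python sketch of the truncation step under that invariant, on plain (stream_id, data) tuples:

# Truncate a limited batch so the last stream ID included is complete.
# Illustrative sketch; rows are (stream_id, data) tuples sorted by stream_id.
def truncate_to_complete_stream_ids(rows, target_row_count):
    if len(rows) < target_row_count:
        return rows, False  # under the limit: nothing was cut off
    last_full = rows[-1][0] - 1  # the final stream ID may be partial
    for idx in range(len(rows) - 1, 0, -1):
        if rows[idx - 1][0] <= last_full:
            return rows[:idx], True
    return [], True  # not even one complete stream ID in the batch

rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]
print(truncate_to_complete_stream_ids(rows, 4))  # ([(1, 'a'), (2, 'b'), (2, 'c')], True)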
- - rows = await self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) # type: List[Tuple] - - # if we've got fewer rows than the limit, we're good - if len(rows) < target_row_count: - return rows, to_token, False - - # we hit the limit, so reduce the upper limit so that we exclude the stream id - # of the last row in the result. - assert rows[-1][0] <= to_token - to_token = rows[-1][0] - 1 - - # search backwards through the list for the point to truncate - for idx in range(len(rows) - 1, 0, -1): - if rows[idx - 1][0] <= to_token: - return rows[:idx], to_token, True - - # bother. We didn't get a full set of changes for even a single - # stream id. let's run the query again, without a row limit, but for - # just one stream id. - to_token += 1 - rows = await self.db.runInteraction( - "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token - ) - - return rows, to_token, True - - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" 
- " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) - - async def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ - to_1, so_1 = await self.get_event_ordering(event_id1) - to_2, so_2 = await self.get_event_ordering(event_id2) - return (to_1, so_1) > (to_2, so_2) - - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db.simple_select_one( - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True, - ) - - if not res: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - return (int(res["topological_ordering"]), int(res["stream_ordering"])) - - def get_next_event_to_expire(self): - """Retrieve the entry with the lowest expiry timestamp in the event_expiry - table, or None if there's no more event to expire. - - Returns: Deferred[Optional[Tuple[str, int]]] - A tuple containing the event ID as its first element and an expiry timestamp - as its second one, if there's at least one row in the event_expiry table. - None otherwise. - """ - - def get_next_event_to_expire_txn(txn): - txn.execute( - """ - SELECT event_id, expiry_ts FROM event_expiry - ORDER BY expiry_ts ASC LIMIT 1 - """ - ) - - return txn.fetchone() - - return self.db.runInteraction( - desc="get_next_event_to_expire", func=get_next_event_to_expire_txn - ) - - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? 
- AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py deleted file mode 100644 index 342d6622a4..0000000000 --- a/synapse/storage/data_stores/main/filtering.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from canonicaljson import encode_canonical_json - -from synapse.api.errors import Codes, SynapseError -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks - - -class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): - # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail - # with a coherent error message rather than 500 M_UNKNOWN. - try: - int(filter_id) - except ValueError: - raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - - def_json = yield self.db.simple_select_one_onecol( - table="user_filters", - keyvalues={"user_id": user_localpart, "filter_id": filter_id}, - retcol="filter_json", - allow_none=False, - desc="get_user_filter", - ) - - return db_to_json(def_json) - - def add_user_filter(self, user_localpart, user_filter): - def_json = encode_canonical_json(user_filter) - - # Need an atomic transaction to SELECT the maximal ID so far then - # INSERT a new one - def _do_txn(txn): - sql = ( - "SELECT filter_id FROM user_filters " - "WHERE user_id = ? AND filter_json = ?" - ) - txn.execute(sql, (user_localpart, bytearray(def_json))) - filter_id_response = txn.fetchone() - if filter_id_response is not None: - return filter_id_response[0] - - sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" 
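add_user_filter above allocates IDs with a SELECT MAX(filter_id) + 1 inside the same transaction that inserts the row, after first checking whether an identical filter already exists for that user. A self-contained sqlite sketch of that allocate-or-reuse transaction (the table name mirrors the schema above, everything else is illustrative):

# Allocate-or-reuse of per-user filter IDs in one transaction; sketch only.
import json
import sqlite3

def add_user_filter(conn, user_id, user_filter):
    filter_json = json.dumps(user_filter, sort_keys=True)
    with conn:  # one transaction for the SELECTs and the INSERT
        row = conn.execute(
            "SELECT filter_id FROM user_filters WHERE user_id = ? AND filter_json = ?",
            (user_id, filter_json),
        ).fetchone()
        if row is not None:
            return row[0]  # identical filter already stored: reuse its ID
        max_id = conn.execute(
            "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?", (user_id,)
        ).fetchone()[0]
        filter_id = 0 if max_id is None else max_id + 1
        conn.execute(
            "INSERT INTO user_filters (user_id, filter_id, filter_json) VALUES (?, ?, ?)",
            (user_id, filter_id, filter_json),
        )
    return filter_id

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_filters (user_id TEXT, filter_id INTEGER, filter_json TEXT)")
print(add_user_filter(conn, "@user:example.com", {"room": {"timeline": {"limit": 10}}}))  # 0
print(add_user_filter(conn, "@user:example.com", {"room": {"timeline": {"limit": 10}}}))  # 0 again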
- txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] - if max_id is None: - filter_id = 0 - else: - filter_id = max_id + 1 - - sql = ( - "INSERT INTO user_filters (user_id, filter_id, filter_json)" - "VALUES(?, ?, ?)" - ) - txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) - - return filter_id - - return self.db.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py deleted file mode 100644 index 01ff561e1a..0000000000 --- a/synapse/storage/data_stores/main/group_server.py +++ /dev/null @@ -1,1295 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, db_to_json - -# The category ID for the "default" category. We don't store as null in the -# database to avoid the fun of null != null -_DEFAULT_CATEGORY_ID = "" -_DEFAULT_ROLE_ID = "" - - -class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db.simple_select_one( - table="groups", - keyvalues={"group_id": group_id}, - retcols=( - "name", - "short_description", - "long_description", - "avatar_url", - "is_public", - "join_policy", - ), - allow_none=True, - desc="get_group", - ) - - def get_users_in_group(self, group_id, include_private=False): - # TODO: Pagination - - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - return self.db.simple_select_list( - table="group_users", - keyvalues=keyvalues, - retcols=("user_id", "is_public", "is_admin"), - desc="get_users_in_group", - ) - - def get_invited_users_in_group(self, group_id): - # TODO: Pagination - - return self.db.simple_select_onecol( - table="group_invites", - keyvalues={"group_id": group_id}, - retcol="user_id", - desc="get_invited_users_in_group", - ) - - def get_rooms_in_group(self, group_id: str, include_private: bool = False): - """Retrieve the rooms that belong to a given group. Does not return rooms that - lack members. - - Args: - group_id: The ID of the group to query for rooms - include_private: Whether to return private rooms in results - - Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: - - { - "room_id": "!a_room_id:example.com", # The ID of the room - "is_public": False # Whether this is a public room or not - } - """ - # TODO: Pagination - - def _get_rooms_in_group_txn(txn): - sql = """ - SELECT room_id, is_public FROM group_rooms - WHERE group_id = ? 
- AND room_id IN ( - SELECT group_rooms.room_id FROM group_rooms - LEFT JOIN room_stats_current ON - group_rooms.room_id = room_stats_current.room_id - AND joined_members > 0 - AND local_users_in_room > 0 - LEFT JOIN rooms ON - group_rooms.room_id = rooms.room_id - AND (room_version <> '') = ? - ) - """ - args = [group_id, False] - - if not include_private: - sql += " AND is_public = ?" - args += [True] - - txn.execute(sql, args) - - return [ - {"room_id": room_id, "is_public": is_public} - for room_id, is_public in txn - ] - - return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) - - def get_rooms_for_summary_by_category( - self, group_id: str, include_private: bool = False, - ): - """Get the rooms and categories that should be included in a summary request - - Args: - group_id: The ID of the group to query the summary for - include_private: Whether to return private rooms in results - - Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: - - * A list of dictionaries with the keys: - * "room_id": str, the room ID - * "is_public": bool, whether the room is public - * "category_id": str|None, the category ID if set, else None - * "order": int, the sort order of rooms - - * A dictionary with the key: - * category_id (str): a dictionary with the keys: - * "is_public": bool, whether the category is public - * "profile": str, the category profile - * "order": int, the sort order of rooms in this category - """ - - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT room_id, is_public, category_id, room_order - FROM group_summary_rooms - WHERE group_id = ? - AND room_id IN ( - SELECT group_rooms.room_id FROM group_rooms - LEFT JOIN room_stats_current ON - group_rooms.room_id = room_stats_current.room_id - AND joined_members > 0 - AND local_users_in_room > 0 - LEFT JOIN rooms ON - group_rooms.room_id = rooms.room_id - AND (room_version <> '') = ? - ) - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, False, True)) - else: - txn.execute(sql, (group_id, False)) - - rooms = [ - { - "room_id": row[0], - "is_public": row[1], - "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT category_id, is_public, profile, cat_order - FROM group_summary_room_categories - INNER JOIN group_room_categories USING (group_id, category_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" 
- txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - categories = { - row[0]: { - "is_public": row[1], - "profile": db_to_json(row[2]), - "order": row[3], - } - for row in txn - } - - return rooms, categories - - return self.db.runInteraction( - "get_rooms_for_summary", _get_rooms_for_summary_txn - ) - - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db.simple_select_list( - table="group_room_categories", - keyvalues={"group_id": group_id}, - retcols=("category_id", "is_public", "profile"), - desc="get_group_categories", - ) - - return { - row["category_id"]: { - "is_public": row["is_public"], - "profile": db_to_json(row["profile"]), - } - for row in rows - } - - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db.simple_select_one( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcols=("is_public", "profile"), - desc="get_group_category", - ) - - category["profile"] = db_to_json(category["profile"]) - - return category - - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db.simple_select_list( - table="group_roles", - keyvalues={"group_id": group_id}, - retcols=("role_id", "is_public", "profile"), - desc="get_group_roles", - ) - - return { - row["role_id"]: { - "is_public": row["is_public"], - "profile": db_to_json(row["profile"]), - } - for row in rows - } - - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db.simple_select_one( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcols=("is_public", "profile"), - desc="get_group_role", - ) - - role["profile"] = db_to_json(role["profile"]) - - return role - - def get_local_groups_for_room(self, room_id): - """Get all of the local group that contain a given room - Args: - room_id (str): The ID of a room - Returns: - Deferred[list[str]]: A twisted.Deferred containing a list of group ids - containing this room - """ - return self.db.simple_select_onecol( - table="group_rooms", - keyvalues={"room_id": room_id}, - retcol="group_id", - desc="get_local_groups_for_room", - ) - - def get_users_for_summary_by_role(self, group_id, include_private=False): - """Get the users and roles that should be included in a summary request - - Returns ([users], [roles]) - """ - - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT user_id, is_public, role_id, user_order - FROM group_summary_users - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - users = [ - { - "user_id": row[0], - "is_public": row[1], - "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT role_id, is_public, profile, role_order - FROM group_summary_roles - INNER JOIN group_roles USING (group_id, role_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" 
- txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - roles = { - row[0]: { - "is_public": row[1], - "profile": db_to_json(row[2]), - "order": row[3], - } - for row in txn - } - - return users, roles - - return self.db.runInteraction( - "get_users_for_summary_by_role", _get_users_for_summary_txn - ) - - def is_user_in_group(self, user_id, group_id): - return self.db.simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) - - def is_user_admin_in_group(self, group_id, user_id): - return self.db.simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="is_admin", - allow_none=True, - desc="is_user_admin_in_group", - ) - - def is_user_invited_to_local_group(self, group_id, user_id): - """Has the group server invited a user? - """ - return self.db.simple_select_one_onecol( - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - desc="is_user_invited_to_local_group", - allow_none=True, - ) - - def get_users_membership_info_in_group(self, group_id, user_id): - """Get a dict describing the membership of a user in a group. - - Example if joined: - - { - "membership": "join", - "is_public": True, - "is_privileged": False, - } - - Returns an empty dict if the user is not join/invite/etc - """ - - def _get_users_membership_in_group_txn(txn): - row = self.db.simple_select_one_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("is_admin", "is_public"), - allow_none=True, - ) - - if row: - return { - "membership": "join", - "is_public": row["is_public"], - "is_privileged": row["is_admin"], - } - - row = self.db.simple_select_one_onecol_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - - if row: - return {"membership": "invite"} - - return {} - - return self.db.runInteraction( - "get_users_membership_info_in_group", _get_users_membership_in_group_txn - ) - - def get_publicised_groups_for_user(self, user_id): - """Get all groups a user is publicising - """ - return self.db.simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, - retcol="group_id", - desc="get_publicised_groups_for_user", - ) - - def get_attestations_need_renewals(self, valid_until_ms): - """Get all attestations that need to be renewed until givent time - """ - - def _get_attestations_need_renewals_txn(txn): - sql = """ - SELECT group_id, user_id FROM group_attestations_renewals - WHERE valid_until_ms <= ? - """ - txn.execute(sql, (valid_until_ms,)) - return self.db.cursor_to_dict(txn) - - return self.db.runInteraction( - "get_attestations_need_renewals", _get_attestations_need_renewals_txn - ) - - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): - """Get the attestation that proves the remote agrees that the user is - in the group. 
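Group attestations are only trusted while their valid_until_ms lies in the future, which is why the renewal query above selects on that column and the lookup below compares it against the clock. A tiny sketch of that validity check on a plain dict row (the field names mirror the table above; the clock is simply time.time):

# Return the parsed attestation only while it is still valid; sketch only.
import json
import time

def usable_attestation(row, now_ms=None):
    if row is None:
        return None
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    if now_ms >= row["valid_until_ms"]:
        return None  # expired: the attestation needs renewing, not using
    return json.loads(row["attestation_json"])

row = {
    "valid_until_ms": int(time.time() * 1000) + 60_000,
    "attestation_json": '{"user_id": "@user:example.com"}',
}
print(usable_attestation(row))  # {'user_id': '@user:example.com'}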
- """ - row = yield self.db.simple_select_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("valid_until_ms", "attestation_json"), - desc="get_remote_attestation", - allow_none=True, - ) - - now = int(self._clock.time_msec()) - if row and now < row["valid_until_ms"]: - return db_to_json(row["attestation_json"]) - - return None - - def get_joined_groups(self, user_id): - return self.db.simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join"}, - retcol="group_id", - desc="get_joined_groups", - ) - - def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): - sql = """ - SELECT group_id, type, membership, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND membership != 'leave' - AND stream_id <= ? - """ - txn.execute(sql, (user_id, now_token)) - return [ - { - "group_id": row[0], - "type": row[1], - "membership": row[2], - "content": db_to_json(row[3]), - } - for row in txn - ] - - return self.db.runInteraction( - "get_all_groups_for_user", _get_all_groups_for_user_txn - ) - - def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( - user_id, from_token - ) - if not has_changed: - return defer.succeed([]) - - def _get_groups_changes_for_user_txn(txn): - sql = """ - SELECT group_id, membership, type, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? - """ - txn.execute(sql, (user_id, from_token, to_token)) - return [ - { - "group_id": group_id, - "membership": membership, - "type": gtype, - "content": db_to_json(content_json), - } - for group_id, membership, gtype, content_json in txn - ] - - return self.db.runInteraction( - "get_groups_changes_for_user", _get_groups_changes_for_user_txn - ) - - async def get_all_groups_changes( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for groups replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) - - if not has_changed: - return [], current_id, False - - def _get_all_groups_changes_txn(txn): - sql = """ - SELECT stream_id, group_id, user_id, type, content - FROM local_group_updates - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? 
- """ - txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ] - - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db.runInteraction( - "get_all_groups_changes", _get_all_groups_changes_txn - ) - - -class GroupServerStore(GroupServerWorkerStore): - def set_group_join_policy(self, group_id, join_policy): - """Set the join policy of a group. - - join_policy can be one of: - * "invite" - * "open" - """ - return self.db.simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues={"join_policy": join_policy}, - desc="set_group_join_policy", - ) - - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db.runInteraction( - "add_room_to_summary", - self._add_room_to_summary_txn, - group_id, - room_id, - category_id, - order, - is_public, - ) - - def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): - """Add (or update) room's entry in summary. - - Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. - """ - room_in_group = self.db.simple_select_one_onecol_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - retcol="room_id", - allow_none=True, - ) - if not room_in_group: - raise SynapseError(400, "room not in group") - - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - else: - cat_exists = self.db.simple_select_one_onecol_txn( - txn, - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - raise SynapseError(400, "Category doesn't exist") - - # TODO: Check category is part of summary already - cat_exists = self.db.simple_select_one_onecol_txn( - txn, - table="group_summary_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_room_categories - (group_id, category_id, cat_order) - SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 - FROM group_summary_room_categories - WHERE group_id = ? AND category_id = ? - """, - (group_id, category_id, group_id, category_id), - ) - - existing = self.db.simple_select_one_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - "category_id": category_id, - }, - retcols=("room_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other room orders that come after the given order - sql = """ - UPDATE group_summary_rooms SET room_order = room_order + 1 - WHERE group_id = ? AND category_id = ? AND room_order >= ? - """ - txn.execute(sql, (group_id, category_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms - WHERE group_id = ? AND category_id = ? 
- """ - txn.execute(sql, (group_id, category_id)) - (order,) = txn.fetchone() - - if existing: - to_update = {} - if order is not None: - to_update["room_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self.db.simple_update_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - values=to_update, - ) - else: - if is_public is None: - is_public = True - - self.db.simple_insert_txn( - txn, - table="group_summary_rooms", - values={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - "room_order": order, - "is_public": is_public, - }, - ) - - def remove_room_from_summary(self, group_id, room_id, category_id): - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - - return self.db.simple_delete( - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - desc="remove_room_from_summary", - ) - - def upsert_group_category(self, group_id, category_id, profile, is_public): - """Add/update room category for group - """ - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json.dumps(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - return self.db.simple_upsert( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_category", - ) - - def remove_group_category(self, group_id, category_id): - return self.db.simple_delete( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - desc="remove_group_category", - ) - - def upsert_group_role(self, group_id, role_id, profile, is_public): - """Add/remove user role - """ - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json.dumps(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - return self.db.simple_upsert( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_role", - ) - - def remove_group_role(self, group_id, role_id): - return self.db.simple_delete( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - desc="remove_group_role", - ) - - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db.runInteraction( - "add_user_to_summary", - self._add_user_to_summary_txn, - group_id, - user_id, - role_id, - order, - is_public, - ) - - def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public - ): - """Add (or update) user's entry in summary. - - Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. 
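The ordering rules described above (and used by _add_room_to_summary_txn earlier) follow one pattern: an explicit order shifts every existing entry at or after that position up by one, while no order means append at MAX(order) + 1. A pure-Python model of just that bookkeeping, with a dict standing in for the summary rows:

# In-memory model of the summary ordering rules; illustration only.
def add_to_summary(entries, new_id, order=None):
    """entries: dict of id -> order. Returns the order assigned to new_id."""
    if order is not None:
        for other_id, other_order in entries.items():
            if other_order >= order:
                entries[other_id] = other_order + 1  # shuffle later entries up
    else:
        order = max(entries.values(), default=0) + 1  # append at the end
    entries[new_id] = order
    return order

summary = {}
add_to_summary(summary, "@a:example.com")           # -> order 1
add_to_summary(summary, "@b:example.com")           # -> order 2
add_to_summary(summary, "@c:example.com", order=1)  # -> first; @a and @b shift to 2 and 3
print(summary)  # {'@a:example.com': 2, '@b:example.com': 3, '@c:example.com': 1}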
- """ - user_in_group = self.db.simple_select_one_onecol_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - if not user_in_group: - raise SynapseError(400, "user not in group") - - if role_id is None: - role_id = _DEFAULT_ROLE_ID - else: - role_exists = self.db.simple_select_one_onecol_txn( - txn, - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - raise SynapseError(400, "Role doesn't exist") - - # TODO: Check role is part of the summary already - role_exists = self.db.simple_select_one_onecol_txn( - txn, - table="group_summary_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_roles - (group_id, role_id, role_order) - SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 - FROM group_summary_roles - WHERE group_id = ? AND role_id = ? - """, - (group_id, role_id, group_id, role_id), - ) - - existing = self.db.simple_select_one_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, - retcols=("user_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other users orders that come after the given order - sql = """ - UPDATE group_summary_users SET user_order = user_order + 1 - WHERE group_id = ? AND role_id = ? AND user_order >= ? - """ - txn.execute(sql, (group_id, role_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users - WHERE group_id = ? AND role_id = ? - """ - txn.execute(sql, (group_id, role_id)) - (order,) = txn.fetchone() - - if existing: - to_update = {} - if order is not None: - to_update["user_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self.db.simple_update_txn( - txn, - table="group_summary_users", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - }, - values=to_update, - ) - else: - if is_public is None: - is_public = True - - self.db.simple_insert_txn( - txn, - table="group_summary_users", - values={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - "user_order": order, - "is_public": is_public, - }, - ) - - def remove_user_from_summary(self, group_id, user_id, role_id): - if role_id is None: - role_id = _DEFAULT_ROLE_ID - - return self.db.simple_delete( - table="group_summary_users", - keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, - desc="remove_user_from_summary", - ) - - def add_group_invite(self, group_id, user_id): - """Record that the group server has invited a user - """ - return self.db.simple_insert( - table="group_invites", - values={"group_id": group_id, "user_id": user_id}, - desc="add_group_invite", - ) - - def add_user_to_group( - self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): - """Add a user to the group server. - - Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote - server. 
Optional if the user and group are on the same server - """ - - def _add_user_to_group_txn(txn): - self.db.simple_insert_txn( - txn, - table="group_users", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "is_public": is_public, - }, - ) - - self.db.simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - if local_attestation: - self.db.simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self.db.simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), - }, - ) - - return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) - - def remove_user_from_group(self, group_id, user_id): - def _remove_user_from_group_txn(txn): - self.db.simple_delete_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db.simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db.simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db.simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db.simple_delete_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - return self.db.runInteraction( - "remove_user_from_group", _remove_user_from_group_txn - ) - - def add_room_to_group(self, group_id, room_id, is_public): - return self.db.simple_insert( - table="group_rooms", - values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, - desc="add_room_to_group", - ) - - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db.simple_update( - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - updatevalues={"is_public": is_public}, - desc="update_room_in_group_visibility", - ) - - def remove_room_from_group(self, group_id, room_id): - def _remove_room_from_group_txn(txn): - self.db.simple_delete_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - self.db.simple_delete_txn( - txn, - table="group_summary_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - return self.db.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn - ) - - def update_group_publicity(self, group_id, user_id, publicise): - """Update whether the user is publicising their membership of the group - """ - return self.db.simple_update_one( - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"is_publicised": publicise}, - desc="update_group_publicity", - ) - - @defer.inlineCallbacks - def register_user_group_membership( - self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): - """Registers that a local user is a member of a (local or remote) group. - - Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. 
includes the inviter - if the user has been invited. - local_attestation (dict): If remote group then store the fact that we - have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote - attestation from the group, else None. - """ - - def _register_user_group_membership_txn(txn, next_id): - # TODO: Upsert? - self.db.simple_delete_txn( - txn, - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db.simple_insert_txn( - txn, - table="local_group_membership", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "membership": membership, - "is_publicised": is_publicised, - "content": json.dumps(content), - }, - ) - - self.db.simple_insert_txn( - txn, - table="local_group_updates", - values={ - "stream_id": next_id, - "group_id": group_id, - "user_id": user_id, - "type": "membership", - "content": json.dumps( - {"membership": membership, "content": content} - ), - }, - ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) - - # TODO: Insert profile to ensure it comes down stream if its a join. - - if membership == "join": - if local_attestation: - self.db.simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self.db.simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), - }, - ) - else: - self.db.simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db.simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - return next_id - - with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db.runInteraction( - "register_user_group_membership", - _register_user_group_membership_txn, - next_id, - ) - return res - - @defer.inlineCallbacks - def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db.simple_insert( - table="groups", - values={ - "group_id": group_id, - "name": name, - "avatar_url": avatar_url, - "short_description": short_description, - "long_description": long_description, - "is_public": True, - }, - desc="create_group", - ) - - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db.simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues=profile, - desc="update_group_profile", - ) - - def update_attestation_renewal(self, group_id, user_id, attestation): - """Update an attestation that we have renewed - """ - return self.db.simple_update_one( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, - desc="update_attestation_renewal", - ) - - def update_remote_attestion(self, group_id, user_id, attestation): - """Update an attestation that a remote has renewed - """ - return self.db.simple_update_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={ - "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), - }, - 
desc="update_remote_attestion", - ) - - def remove_attestation_renewal(self, group_id, user_id): - """Remove an attestation that we thought we should renew, but actually - shouldn't. Ideally this would never get called as we would never - incorrectly try and do attestations for local users on local groups. - - Args: - group_id (str) - user_id (str) - """ - return self.db.simple_delete( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - desc="remove_attestation_renewal", - ) - - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() - - def delete_group(self, group_id): - """Deletes a group fully from the database. - - Args: - group_id (str) - - Returns: - Deferred - """ - - def _delete_group_txn(txn): - tables = [ - "groups", - "group_users", - "group_invites", - "group_rooms", - "group_summary_rooms", - "group_summary_room_categories", - "group_room_categories", - "group_summary_users", - "group_summary_roles", - "group_roles", - "group_attestations_renewals", - "group_attestations_remote", - ] - - for table in tables: - self.db.simple_delete_txn( - txn, table=table, keyvalues={"group_id": group_id} - ) - - return self.db.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py deleted file mode 100644 index 4e1642a27a..0000000000 --- a/synapse/storage/data_stores/main/keys.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging - -from signedjson.key import decode_verify_key_bytes - -from synapse.storage._base import SQLBaseStore -from synapse.storage.keys import FetchKeyResult -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.iterutils import batch_iter - -logger = logging.getLogger(__name__) - - -db_binary_type = memoryview - - -class KeyStore(SQLBaseStore): - """Persistence for signature verification keys - """ - - @cached() - def _get_server_verify_key(self, server_name_and_key_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" - ) - def get_server_verify_keys(self, server_name_and_key_ids): - """ - Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown - """ - keys = {} - - def _get_keys(txn, batch): - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = ( - "SELECT server_name, key_id, verify_key, ts_valid_until_ms " - "FROM server_signature_keys WHERE 1=0" - ) + " OR (server_name=? 
AND key_id=?)" * len(batch) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - res = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - keys[(server_name, key_id)] = res - - def _txn(txn): - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return self.db.runInteraction("get_server_verify_keys", _txn) - - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): - """Stores NACL verification keys for remote servers. - Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). - """ - key_values = [] - value_values = [] - invalidations = [] - for server_name, key_id, fetch_result in verify_keys: - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_verify_key. _get_server_verify_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.db.runInteraction( - "store_server_verify_keys", - self.db.simple_upsert_many_txn, - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - ).addCallback(_invalidate) - - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. - """ - return self.db.simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) - - def get_server_keys_json(self, server_keys): - """Retrive the key json for a list of server_keys and key ids. 
- If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. - Args: - server_keys (list): List of (server_name, key_id, source) triplets. - Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts - """ - - def _get_server_keys_json_txn(txn): - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self.db.simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results - - return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py deleted file mode 100644 index 15bc13cbd0..0000000000 --- a/synapse/storage/data_stores/main/media_repository.py +++ /dev/null @@ -1,394 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database - - -class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__( - database, db_conn, hs - ) - - self.db.updates.register_background_index_update( - update_name="local_media_repository_url_idx", - index_name="local_media_repository_url_idx", - table="local_media_repository", - columns=["created_ts"], - where_clause="url_cache IS NOT NULL", - ) - - -class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): - """Persistence for attachments and avatars""" - - def __init__(self, database: Database, db_conn, hs): - super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - - def get_local_media(self, media_id): - """Get the metadata for a local piece of media - Returns: - None if the media_id doesn't exist. 
- """ - return self.db.simple_select_one( - "local_media_repository", - {"media_id": media_id}, - ( - "media_type", - "media_length", - "upload_name", - "created_ts", - "quarantined_by", - "url_cache", - ), - allow_none=True, - desc="get_local_media", - ) - - def store_local_media( - self, - media_id, - media_type, - time_now_ms, - upload_name, - media_length, - user_id, - url_cache=None, - ): - return self.db.simple_insert( - "local_media_repository", - { - "media_id": media_id, - "media_type": media_type, - "created_ts": time_now_ms, - "upload_name": upload_name, - "media_length": media_length, - "user_id": user_id.to_string(), - "url_cache": url_cache, - }, - desc="store_local_media", - ) - - def mark_local_media_as_safe(self, media_id: str): - """Mark a local media as safe from quarantining.""" - return self.db.simple_update_one( - table="local_media_repository", - keyvalues={"media_id": media_id}, - updatevalues={"safe_from_quarantine": True}, - desc="mark_local_media_as_safe", - ) - - def get_url_cache(self, url, ts): - """Get the media_id and ts for a cached URL as of the given timestamp - Returns: - None if the URL isn't cached. - """ - - def get_url_cache_txn(txn): - # get the most recently cached result (relative to the given ts) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts <= ?" - " ORDER BY download_ts DESC LIMIT 1" - ) - txn.execute(sql, (url, ts)) - row = txn.fetchone() - - if not row: - # ...or if we've requested a timestamp older than the oldest - # copy in the cache, return the oldest copy (if any) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts > ?" 
- " ORDER BY download_ts ASC LIMIT 1" - ) - txn.execute(sql, (url, ts)) - row = txn.fetchone() - - if not row: - return None - - return dict( - zip( - ( - "response_code", - "etag", - "expires_ts", - "og", - "media_id", - "download_ts", - ), - row, - ) - ) - - return self.db.runInteraction("get_url_cache", get_url_cache_txn) - - def store_url_cache( - self, url, response_code, etag, expires_ts, og, media_id, download_ts - ): - return self.db.simple_insert( - "local_media_repository_url_cache", - { - "url": url, - "response_code": response_code, - "etag": etag, - "expires_ts": expires_ts, - "og": og, - "media_id": media_id, - "download_ts": download_ts, - }, - desc="store_url_cache", - ) - - def get_local_media_thumbnails(self, media_id): - return self.db.simple_select_list( - "local_media_repository_thumbnails", - {"media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", - ), - desc="get_local_media_thumbnails", - ) - - def store_local_thumbnail( - self, - media_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): - return self.db.simple_insert( - "local_media_repository_thumbnails", - { - "media_id": media_id, - "thumbnail_width": thumbnail_width, - "thumbnail_height": thumbnail_height, - "thumbnail_method": thumbnail_method, - "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, - }, - desc="store_local_thumbnail", - ) - - def get_cached_remote_media(self, origin, media_id): - return self.db.simple_select_one( - "remote_media_cache", - {"media_origin": origin, "media_id": media_id}, - ( - "media_type", - "media_length", - "upload_name", - "created_ts", - "filesystem_id", - "quarantined_by", - ), - allow_none=True, - desc="get_cached_remote_media", - ) - - def store_cached_remote_media( - self, - origin, - media_id, - media_type, - media_length, - time_now_ms, - upload_name, - filesystem_id, - ): - return self.db.simple_insert( - "remote_media_cache", - { - "media_origin": origin, - "media_id": media_id, - "media_type": media_type, - "media_length": media_length, - "created_ts": time_now_ms, - "upload_name": upload_name, - "filesystem_id": filesystem_id, - "last_access_ts": time_now_ms, - }, - desc="store_cached_remote_media", - ) - - def update_cached_last_access_time(self, local_media, remote_media, time_ms): - """Updates the last access time of the given media - - Args: - local_media (iterable[str]): Set of media_ids - remote_media (iterable[(str, str)]): Set of (server_name, media_id) - time_ms: Current time in milliseconds - """ - - def update_cache_txn(txn): - sql = ( - "UPDATE remote_media_cache SET last_access_ts = ?" - " WHERE media_origin = ? AND media_id = ?" - ) - - txn.executemany( - sql, - ( - (time_ms, media_origin, media_id) - for media_origin, media_id in remote_media - ), - ) - - sql = ( - "UPDATE local_media_repository SET last_access_ts = ?" - " WHERE media_id = ?" 
- ) - - txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - - return self.db.runInteraction( - "update_cached_last_access_time", update_cache_txn - ) - - def get_remote_media_thumbnails(self, origin, media_id): - return self.db.simple_select_list( - "remote_media_cache_thumbnails", - {"media_origin": origin, "media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", - "filesystem_id", - ), - desc="get_remote_media_thumbnails", - ) - - def store_remote_media_thumbnail( - self, - origin, - media_id, - filesystem_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): - return self.db.simple_insert( - "remote_media_cache_thumbnails", - { - "media_origin": origin, - "media_id": media_id, - "thumbnail_width": thumbnail_width, - "thumbnail_height": thumbnail_height, - "thumbnail_method": thumbnail_method, - "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, - "filesystem_id": filesystem_id, - }, - desc="store_remote_media_thumbnail", - ) - - def get_remote_media_before(self, before_ts): - sql = ( - "SELECT media_origin, media_id, filesystem_id" - " FROM remote_media_cache" - " WHERE last_access_ts < ?" - ) - - return self.db.execute( - "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts - ) - - def delete_remote_media(self, media_origin, media_id): - def delete_remote_media_txn(txn): - self.db.simple_delete_txn( - txn, - "remote_media_cache", - keyvalues={"media_origin": media_origin, "media_id": media_id}, - ) - self.db.simple_delete_txn( - txn, - "remote_media_cache_thumbnails", - keyvalues={"media_origin": media_origin, "media_id": media_id}, - ) - - return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) - - def get_expired_url_cache(self, now_ts): - sql = ( - "SELECT media_id FROM local_media_repository_url_cache" - " WHERE expires_ts < ?" - " ORDER BY expires_ts ASC" - " LIMIT 500" - ) - - def _get_expired_url_cache_txn(txn): - txn.execute(sql, (now_ts,)) - return [row[0] for row in txn] - - return self.db.runInteraction( - "get_expired_url_cache", _get_expired_url_cache_txn - ) - - async def delete_url_cache(self, media_ids): - if len(media_ids) == 0: - return - - sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" - - def _delete_url_cache_txn(txn): - txn.executemany(sql, [(media_id,) for media_id in media_ids]) - - return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) - - def get_url_cache_media_before(self, before_ts): - sql = ( - "SELECT media_id FROM local_media_repository" - " WHERE created_ts < ? AND url_cache IS NOT NULL" - " ORDER BY created_ts ASC" - " LIMIT 500" - ) - - def _get_url_cache_media_before_txn(txn): - txn.execute(sql, (before_ts,)) - return [row[0] for row in txn] - - return self.db.runInteraction( - "get_url_cache_media_before", _get_url_cache_media_before_txn - ) - - async def delete_url_cache_media(self, media_ids): - if len(media_ids) == 0: - return - - def _delete_url_cache_media_txn(txn): - sql = "DELETE FROM local_media_repository WHERE media_id = ?" - - txn.executemany(sql, [(media_id,) for media_id in media_ids]) - - sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" 
- - txn.executemany(sql, [(media_id,) for media_id in media_ids]) - - return await self.db.runInteraction( - "delete_url_cache_media", _delete_url_cache_media_txn - ) diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py deleted file mode 100644 index dad5bbc602..0000000000 --- a/synapse/storage/data_stores/main/metrics.py +++ /dev/null @@ -1,128 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import typing -from collections import Counter - -from twisted.internet import defer - -from synapse.metrics import BucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.event_push_actions import ( - EventPushActionsWorkerStore, -) -from synapse.storage.database import Database - - -class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): - """Functions to pull various metrics from the DB, for e.g. phone home - stats and prometheus metrics. - """ - - def __init__(self, database: Database, db_conn, hs): - super().__init__(database, db_conn, hs) - - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = ( - Counter() - ) # type: typing.Counter[int] - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - - # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) - - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) - - async def _read_forward_extremities(self): - def fetch(txn): - txn.execute( - """ - select count(*) c from event_forward_extremities - group by room_id - """ - ) - return txn.fetchall() - - res = await self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = Counter([x[0] for x in res]) - - @defer.inlineCallbacks - def count_daily_messages(self): - """ - Returns an estimate of the number of messages sent in the last day. - - If it has been significantly less or more than one day since the last - call to this function, it will return None. - """ - - def _count_messages(txn): - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? 
- """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_sent_messages(self): - def _count_messages(txn): - # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. - like_clause = "%:" + self.hs.hostname - - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND sender LIKE ? - AND stream_ordering > ? - """ - - txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_active_rooms(self): - def _count(txn): - sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_daily_active_rooms", _count) - return ret diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py deleted file mode 100644 index e459cf49a0..0000000000 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ /dev/null @@ -1,359 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import List - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database, make_in_list_sql_clause -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - -# Number of msec of granularity to store the monthly_active_user timestamp -# This means it is not necessary to update the table on every request -LAST_SEEN_GRANULARITY = 60 * 60 * 1000 - - -class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) - self._clock = hs.get_clock() - self.hs = hs - - @cached(num_args=0) - def get_monthly_active_count(self): - """Generates current count of monthly active users - - Returns: - Defered[int]: Number of current monthly active users - """ - - def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - txn.execute(sql) - (count,) = txn.fetchone() - return count - - return self.db.runInteraction("count_users", _count_users) - - @cached(num_args=0) - def get_monthly_active_count_by_service(self): - """Generates current count of monthly active users broken down by service. - A service is typically an appservice but also includes native matrix users. 
- Since the `monthly_active_users` table is populated from the `user_ips` table - `config.track_appservice_user_ips` must be set to `true` for this - method to return anything other than native matrix users. - - Returns: - Deferred[dict]: dict that includes a mapping between app_service_id - and the number of occurrences. - - """ - - def _count_users_by_service(txn): - sql = """ - SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) - FROM monthly_active_users - LEFT JOIN users ON monthly_active_users.user_id=users.name - GROUP BY appservice_id; - """ - - txn.execute(sql) - result = txn.fetchall() - return dict(result) - - return self.db.runInteraction("count_users_by_service", _count_users_by_service) - - async def get_registered_reserved_users(self) -> List[str]: - """Of the reserved threepids defined in config, retrieve those that are associated - with registered users - - Returns: - User IDs of actual users that are reserved - """ - users = [] - - for tp in self.hs.config.mau_limits_reserved_threepids[ - : self.hs.config.max_mau_value - ]: - user_id = await self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] - ) - if user_id: - users.append(user_id) - - return users - - @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): - """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen - - """ - - return self.db.simple_select_one_onecol( - table="monthly_active_users", - keyvalues={"user_id": user_id}, - retcol="timestamp", - allow_none=True, - desc="user_last_seen_monthly_active", - ) - - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: Database, db_conn, hs): - super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? 
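To make the per-service aggregation in _count_users_by_service above concrete, here is a minimal, self-contained sqlite3 sketch. It is illustrative only, not Synapse code: the table contents are invented and only the query shape is taken from the store. Rows in monthly_active_users whose user has no appservice_id fall into a single 'native' bucket via COALESCE.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE users (name TEXT PRIMARY KEY, appservice_id TEXT);
    CREATE TABLE monthly_active_users (user_id TEXT PRIMARY KEY, timestamp BIGINT);
    INSERT INTO users VALUES ('@a:hs', NULL), ('@b:hs', NULL), ('@bot:hs', 'irc-bridge');
    INSERT INTO monthly_active_users VALUES ('@a:hs', 1), ('@b:hs', 2), ('@bot:hs', 3);
    """
)
rows = conn.execute(
    """
    SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
    FROM monthly_active_users
    LEFT JOIN users ON monthly_active_users.user_id = users.name
    GROUP BY appservice_id
    """
).fetchall()
# Rows with a NULL appservice_id group together and are relabelled 'native' by COALESCE
print(dict(rows))  # counts: native=2, irc-bridge=1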
- - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - - async def reap_monthly_active_users(self): - """Cleans out monthly active user table to ensure that no stale - entries exist. - """ - - def _reap_users(txn, reserved_users): - """ - Args: - reserved_users (tuple): reserved users to preserve - """ - - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - - in_clause, in_clause_args = make_in_list_sql_clause( - self.database_engine, "user_id", reserved_users - ) - - txn.execute( - "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" - % (in_clause,), - [thirty_days_ago] + in_clause_args, - ) - - if self._limit_usage_by_mau: - # If MAU user count still exceeds the MAU threshold, then delete on - # a least recently active basis. - # Note it is not possible to write this query using OFFSET due to - # incompatibilities in how sqlite and postgres support the feature. - # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present, - # while Postgres does not require 'LIMIT', but also does not support - # negative LIMIT values. So there is no way to write it that both can - # support - - # Limit must be >= 0 for postgres - num_of_non_reserved_users_to_remove = max( - self._max_mau_value - len(reserved_users), 0 - ) - - # It is important to filter reserved users twice to guard - # against the case where the reserved user is present in the - # SELECT, meaning that a legitimate mau is deleted. - sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - WHERE NOT %s - ORDER BY timestamp DESC - LIMIT ? - ) - AND NOT %s - """ % ( - in_clause, - in_clause, - ) - - query_args = ( - in_clause_args - + [num_of_non_reserved_users_to_remove] - + in_clause_args - ) - txn.execute(sql, query_args) - - # It seems poor to invalidate the whole cache. Postgres supports - # 'Returning' which would allow me to invalidate only the - # specific users, but sqlite has no way to do this and instead - # I would need to SELECT and the DELETE which without locking - # is racy. - # Have resolved to invalidate the whole cache for now and do - # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( - txn, self.user_last_seen_monthly_active - ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) - - reserved_users = await self.get_registered_reserved_users() - await self.db.runInteraction( - "reap_monthly_active_users", _reap_users, reserved_users - ) - - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): - """Updates or inserts the user into the monthly active user table, which - is used to track the current MAU usage of the server - - Args: - user_id (str): user to add/update - - Returns: - Deferred - """ - # Support user never to be included in MAU stats. Note I can't easily call this - # from upsert_monthly_active_user_txn because then I need a _txn form of - # is_support_user which is complicated because I want to cache the result. 
- # Therefore I call it here and ignore the case where - # upsert_monthly_active_user_txn is called directly from - # _initialise_reserved_users reasoning that it would be very strange to - # include a support user in this context. - - is_support = yield self.is_support_user(user_id) - if is_support: - return - - yield self.db.runInteraction( - "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id - ) - - def upsert_monthly_active_user_txn(self, txn, user_id): - """Updates or inserts monthly active user member - - We consciously do not call is_support_txn from this method because it - is not possible to cache the response. is_support_txn will be false in - almost all cases, so it seems reasonable to call it only for - upsert_monthly_active_user and to call is_support_txn manually - for cases where upsert_monthly_active_user_txn is called directly, - like _initialise_reserved_users - - In short, don't call this method with support users. (Support users - should not appear in the MAU stats). - - Args: - txn (cursor): - user_id (str): user to add/update - - Returns: - bool: True if a new entry was created, False if an - existing one was updated. - """ - - # Am consciously deciding to lock the table on the basis that is ought - # never be a big table and alternative approaches (batching multiple - # upserts into a single txn) introduced a lot of extra complexity. - # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.db.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) - self._invalidate_cache_and_stream( - txn, self.get_monthly_active_count_by_service, () - ) - self._invalidate_cache_and_stream( - txn, self.user_last_seen_monthly_active, (user_id,) - ) - - return is_insert - - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): - """Checks on the state of monthly active user limits and optionally - add the user to the monthly active tables - - Args: - user_id(str): the user_id to query - """ - if self._limit_usage_by_mau or self._mau_stats_only: - # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) - if is_guest: - return - is_trial = yield self.is_trial_user(user_id) - if is_trial: - return - - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) - now = self.hs.get_clock().time_msec() - - # We want to reduce to the total number of db writes, and are happy - # to trade accuracy of timestamp in order to lighten load. This means - # We always insert new users (where MAU threshold has not been reached), - # but only update if we have not previously seen the user for - # LAST_SEEN_GRANULARITY ms - if last_seen_timestamp is None: - # In the case where mau_stats_only is True and limit_usage_by_mau is - # False, there is no point in checking get_monthly_active_count - it - # adds no value and will break the logic if max_mau_value is exceeded. 
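The reap_monthly_active_users query above keeps reserved users unconditionally and then trims the remaining rows down to the configured cap by recency. A rough pure-Python sketch of the same policy, on invented data rather than the store's actual code, with the thirty-day cutoff from the first DELETE assumed to have been applied already:

def reap(active, reserved, max_mau_value):
    """active: dict of user_id -> last-seen timestamp. Returns the surviving entries."""
    # Mirrors max(self._max_mau_value - len(reserved_users), 0) in the store
    keep_budget = max(max_mau_value - len(reserved), 0)
    # Non-reserved users, most recently active first (ORDER BY timestamp DESC)
    non_reserved = sorted(
        (u for u in active if u not in reserved), key=lambda u: active[u], reverse=True
    )
    survivors = set(non_reserved[:keep_budget]) | (set(reserved) & set(active))
    return {u: ts for u, ts in active.items() if u in survivors}


active = {"@reserved:hs": 10, "@a:hs": 40, "@b:hs": 30, "@c:hs": 20}
print(reap(active, reserved={"@reserved:hs"}, max_mau_value=3))
# keeps @reserved:hs plus the two most recently active others (@a:hs and @b:hs)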
- if not self._limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) - else: - count = yield self.get_monthly_active_count() - if count < self._max_mau_value: - yield self.upsert_monthly_active_user(user_id) - elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py deleted file mode 100644 index cc21437e92..0000000000 --- a/synapse/storage/data_stores/main/openid.py +++ /dev/null @@ -1,33 +0,0 @@ -from synapse.storage._base import SQLBaseStore - - -class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db.simple_insert( - table="open_id_tokens", - values={ - "token": token, - "ts_valid_until_ms": ts_valid_until_ms, - "user_id": user_id, - }, - desc="insert_open_id_token", - ) - - def get_user_id_for_open_id_token(self, token, ts_now_ms): - def get_user_id_for_token_txn(txn): - sql = ( - "SELECT user_id FROM open_id_tokens" - " WHERE token = ? AND ? <= ts_valid_until_ms" - ) - - txn.execute(sql, (token, ts_now_ms)) - - rows = txn.fetchall() - if not rows: - return None - else: - return rows[0][0] - - return self.db.runInteraction( - "get_user_id_for_token", get_user_id_for_token_txn - ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py deleted file mode 100644 index 7574612619..0000000000 --- a/synapse/storage/data_stores/main/presence.py +++ /dev/null @@ -1,186 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
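populate_monthly_active_users above deliberately trades timestamp accuracy for fewer writes: a user's row is only refreshed if it is missing or older than LAST_SEEN_GRANULARITY. A minimal sketch of just that decision, ignoring the guest, trial and support exclusions and the MAU-cap check on first insert that the real method also applies:

LAST_SEEN_GRANULARITY = 60 * 60 * 1000  # one hour, in milliseconds


def should_upsert(last_seen_timestamp, now_ms):
    # Never seen before: always insert
    if last_seen_timestamp is None:
        return True
    # Seen before: only rewrite the timestamp once per granularity window
    return now_ms - last_seen_timestamp > LAST_SEEN_GRANULARITY


now_ms = 10 * 60 * 60 * 1000
print(should_upsert(None, now_ms))                         # True: first sighting
print(should_upsert(now_ms - 5 * 60 * 1000, now_ms))       # False: seen five minutes ago
print(should_upsert(now_ms - 2 * 60 * 60 * 1000, now_ms))  # True: seen two hours ago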
- -from typing import List, Tuple - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.presence import UserPresenceState -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.iterutils import batch_iter - - -class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( - len(presence_states) - ) - - with stream_ordering_manager as stream_orderings: - yield self.db.runInteraction( - "update_presence", - self._update_presence_txn, - stream_orderings, - presence_states, - ) - - return stream_orderings[-1], self._presence_id_gen.get_current_token() - - def _update_presence_txn(self, txn, stream_orderings, presence_states): - for stream_id, state in zip(stream_orderings, presence_states): - txn.call_after( - self.presence_stream_cache.entity_has_changed, state.user_id, stream_id - ) - txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) - - # Actually insert new rows - self.db.simple_insert_many_txn( - txn, - table="presence_stream", - values=[ - { - "stream_id": stream_id, - "user_id": state.user_id, - "state": state.state, - "last_active_ts": state.last_active_ts, - "last_federation_update_ts": state.last_federation_update_ts, - "last_user_sync_ts": state.last_user_sync_ts, - "status_msg": state.status_msg, - "currently_active": state.currently_active, - } - for stream_id, state in zip(stream_orderings, presence_states) - ], - ) - - # Delete old rows to stop database from getting really big - sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " - - for states in batch_iter(presence_states, 50): - clause, args = make_in_list_sql_clause( - self.database_engine, "user_id", [s.user_id for s in states] - ) - txn.execute(sql + clause, [stream_id] + list(args)) - - async def get_all_presence_updates( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: - """Get updates for presence replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_presence_updates_txn(txn): - sql = """ - SELECT stream_id, user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, - status_msg, - currently_active - FROM presence_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? 
- """ - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - - upper_bound = current_id - limited = False - if len(updates) >= limit: - upper_bound = updates[-1][0] - limited = True - - return updates, upper_bound, limited - - return await self.db.runInteraction( - "get_all_presence_updates", get_all_presence_updates_txn - ) - - @cached() - def _get_presence_for_user(self, user_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def get_presence_for_users(self, user_ids): - rows = yield self.db.simple_select_many_batch( - table="presence_stream", - column="user_id", - iterable=user_ids, - keyvalues={}, - retcols=( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - desc="get_presence_for_users", - ) - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return {row["user_id"]: UserPresenceState(**row) for row in rows} - - def get_current_presence_token(self): - return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db.simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db.simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py deleted file mode 100644 index bfc9369f0b..0000000000 --- a/synapse/storage/data_stores/main/profile.py +++ /dev/null @@ -1,178 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.roommember import ProfileInfo - - -class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): - try: - profile = yield self.db.simple_select_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcols=("displayname", "avatar_url"), - desc="get_profileinfo", - ) - except StoreError as e: - if e.code == 404: - # no match - return ProfileInfo(None, None) - else: - raise - - return ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] - ) - - def get_profile_displayname(self, user_localpart): - return self.db.simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="displayname", - desc="get_profile_displayname", - ) - - def get_profile_avatar_url(self, user_localpart): - return self.db.simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="avatar_url", - desc="get_profile_avatar_url", - ) - - def get_from_remote_profile_cache(self, user_id): - return self.db.simple_select_one( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - retcols=("displayname", "avatar_url"), - allow_none=True, - desc="get_from_remote_profile_cache", - ) - - def create_profile(self, user_localpart): - return self.db.simple_insert( - table="profiles", values={"user_id": user_localpart}, desc="create_profile" - ) - - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db.simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", - ) - - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db.simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, - desc="set_profile_avatar_url", - ) - - -class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. 
- """ - return self.db.simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) - - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db.simple_update( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - updatevalues={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="update_remote_profile_cache", - ) - - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): - """Check if we still care about the remote user's profile, and if we - don't then remove their profile from the cache - """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) - if not subscribed: - yield self.db.simple_delete( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - desc="delete_remote_profile_cache", - ) - - def get_remote_profile_cache_entries_that_expire(self, last_checked): - """Get all users who haven't been checked since `last_checked` - """ - - def _get_remote_profile_cache_entries_that_expire_txn(txn): - sql = """ - SELECT user_id, displayname, avatar_url - FROM remote_profile_cache - WHERE last_check < ? - """ - - txn.execute(sql, (last_checked,)) - - return self.db.cursor_to_dict(txn) - - return self.db.runInteraction( - "get_remote_profile_cache_entries_that_expire", - _get_remote_profile_cache_entries_that_expire_txn, - ) - - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): - """Check whether we are interested in a remote user's profile. - """ - res = yield self.db.simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - - res = yield self.db.simple_select_one_onecol( - table="group_invites", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py deleted file mode 100644 index b53fe35c33..0000000000 --- a/synapse/storage/data_stores/main/purge_events.py +++ /dev/null @@ -1,400 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import Any, Tuple - -from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore -from synapse.types import RoomStreamToken - -logger = logging.getLogger(__name__) - - -class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): - def purge_history(self, room_id, token, delete_local_events): - """Deletes room history before a certain point - - Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): - if True, we will delete local events as well as remote ones - (instead of just marking them as outliers and deleting their - state groups). - - Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. - """ - - return self.db.runInteraction( - "purge_history", - self._purge_history_txn, - room_id, - token, - delete_local_events, - ) - - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - - # Tables that should be pruned: - # event_auth - # event_backward_extremities - # event_edges - # event_forward_extremities - # event_json - # event_push_actions - # event_reference_hashes - # event_relations - # event_search - # event_to_state_groups - # events - # rejections - # room_depth - # state_groups - # state_groups_state - - # we will build a temporary table listing the events so that we don't - # have to keep shovelling the list back and forth across the - # connection. Annoyingly the python sqlite driver commits the - # transaction on CREATE, so let's do this first. - # - # furthermore, we might already have the table from a previous (failed) - # purge attempt, so let's drop the table first. - - txn.execute("DROP TABLE IF EXISTS events_to_purge") - - txn.execute( - "CREATE TEMPORARY TABLE events_to_purge (" - " event_id TEXT NOT NULL," - " should_delete BOOLEAN NOT NULL" - ")" - ) - - # First ensure that we're not about to delete all the forward extremeties - txn.execute( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?", - (room_id,), - ) - rows = txn.fetchall() - max_depth = max(row[1] for row in rows) - - if max_depth < token.topological: - # We need to ensure we don't delete all the events from the database - # otherwise we wouldn't be able to send any events (due to not - # having any backwards extremeties) - raise SynapseError( - 400, "topological_ordering is greater than forward extremeties" - ) - - logger.info("[purge] looking for events to delete") - - should_delete_expr = "state_key IS NULL" - should_delete_params = () # type: Tuple[Any, ...] - if not delete_local_events: - should_delete_expr += " AND event_id NOT LIKE ?" - - # We include the parameter twice since we use the expression twice - should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) - - should_delete_params += (room_id, token.topological) - - # Note that we insert events that are outliers and aren't going to be - # deleted, as nothing will happen to them. - txn.execute( - "INSERT INTO events_to_purge" - " SELECT event_id, %s" - " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" 
- % (should_delete_expr, should_delete_expr), - should_delete_params, - ) - - # We create the indices *after* insertion as that's a lot faster. - - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)" - ) - - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") - - txn.execute("SELECT event_id, should_delete FROM events_to_purge") - event_rows = txn.fetchall() - logger.info( - "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), - sum(1 for e in event_rows if e[1]), - ) - - logger.info("[purge] Finding new backward extremities") - - # We calculate the new entries for the backward extremeties by finding - # events to be purged that are pointed to by events we're not going to - # purge. - txn.execute( - "SELECT DISTINCT e.event_id FROM events_to_purge AS e" - " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" - " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL" - ) - new_backwards_extrems = txn.fetchall() - - logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) - - txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) - ) - - # Update backward extremeties - txn.executemany( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], - ) - - logger.info("[purge] finding state groups referenced by deleted events") - - # Get all state groups that are referenced by events that are to be - # deleted. - txn.execute( - """ - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) - """ - ) - - referenced_state_groups = {sg for sg, in txn} - logger.info( - "[purge] found %i referenced state groups", len(referenced_state_groups) - ) - - logger.info("[purge] removing events from event_to_state_groups") - txn.execute( - "DELETE FROM event_to_state_groups " - "WHERE event_id IN (SELECT event_id from events_to_purge)" - ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) - - # Delete all remote non-state events - for table in ( - "events", - "event_json", - "event_auth", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_relations", - "event_search", - "rejections", - ): - logger.info("[purge] removing events from %s", table) - - txn.execute( - "DELETE FROM %s WHERE event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,) - ) - - # event_push_actions lacks an index on event_id, and has one on - # (room_id, event_id) instead. - for table in ("event_push_actions",): - logger.info("[purge] removing events from %s", table) - - txn.execute( - "DELETE FROM %s WHERE room_id = ? AND event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), - (room_id,), - ) - - # Mark all state and own events as outliers - logger.info("[purge] marking remaining events as outliers") - txn.execute( - "UPDATE events SET outlier = ?" 
- " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), - ) - - # synapse tries to take out an exclusive lock on room_depth whenever it - # persists events (because upsert), and once we run this update, we - # will block that for the rest of our transaction. - # - # So, let's stick it at the end so that we don't block event - # persistence. - # - # We do this by calculating the minimum depth of the backwards - # extremities. However, the events in event_backward_extremities - # are ones we don't have yet so we need to look at the events that - # point to it via event_edges table. - txn.execute( - """ - SELECT COALESCE(MIN(depth), 0) - FROM event_backward_extremities AS eb - INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id - INNER JOIN events AS e ON e.event_id = eg.event_id - WHERE eb.room_id = ? - """, - (room_id,), - ) - (min_depth,) = txn.fetchone() - - logger.info("[purge] updating room_depth to %d", min_depth) - - txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id), - ) - - # finally, drop the temp table. this will commit the txn in sqlite, - # so make sure to keep this actually last. - txn.execute("DROP TABLE events_to_purge") - - logger.info("[purge] done") - - return referenced_state_groups - - def purge_room(self, room_id): - """Deletes all record of a room - - Args: - room_id (str) - - Returns: - Deferred[List[int]]: The list of state groups to delete. - """ - - return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) - - def _purge_room_txn(self, txn, room_id): - # First we fetch all the state groups that should be deleted, before - # we delete that information. - txn.execute( - """ - SELECT DISTINCT state_group FROM events - INNER JOIN event_to_state_groups USING(event_id) - WHERE events.room_id = ? - """, - (room_id,), - ) - - state_groups = [row[0] for row in txn] - - # Now we delete tables which lack an index on room_id but have one on event_id - for table in ( - "event_auth", - "event_edges", - "event_push_actions_staging", - "event_reference_hashes", - "event_relations", - "event_to_state_groups", - "redactions", - "rejections", - "state_events", - ): - logger.info("[purge] removing %s from %s", room_id, table) - - txn.execute( - """ - DELETE FROM %s WHERE event_id IN ( - SELECT event_id FROM events WHERE room_id=? - ) - """ - % (table,), - (room_id,), - ) - - # and finally, the tables with an index on room_id (or no useful index) - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - # no useful index, but let's clear them anyway - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "room_account_data", - "room_tags", - "local_current_membership", - ): - logger.info("[purge] removing %s from %s", room_id, table) - txn.execute("DELETE FROM %s WHERE room_id=?" 
% (table,), (room_id,)) - - # Other tables we do NOT need to clear out: - # - # - blocked_rooms - # This is important, to make sure that we don't accidentally rejoin a blocked - # room after it was purged - # - # - user_directory - # This has a room_id column, but it is unused - # - - # Other tables that we might want to consider clearing out include: - # - # - event_reports - # Given that these are intended for abuse management my initial - # inclination is to leave them in place. - # - # - current_state_delta_stream - # - ex_outlier_stream - # - room_tags_revisions - # The problem with these is that they are largeish and there is no room_id - # index on them. In any case we should be clearing out 'stream' tables - # periodically anyway (#5888) - - # TODO: we could probably usefully do a bunch of cache invalidation here - - logger.info("[purge] done") - - return state_groups diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py deleted file mode 100644 index c229248101..0000000000 --- a/synapse/storage/data_stores/main/push_rule.py +++ /dev/null @@ -1,759 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import logging -from typing import List, Tuple, Union - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.push.baserules import list_with_base_rules -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.database import Database -from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ChainedIdGenerator -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -def _load_rules(rawrules, enabled_map): - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = db_to_json(rawrule["conditions"]) - rule["actions"] = db_to_json(rawrule["actions"]) - rule["default"] = False - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) - - for i, rule in enumerate(rules): - rule_id = rule["rule_id"] - if rule_id in enabled_map: - if rule.get("enabled", True) != bool(enabled_map[rule_id]): - # Rules are cached across users. 
- rule = dict(rule) - rule["enabled"] = bool(enabled_map[rule_id]) - rules[i] = rule - - return rules - - -class PushRulesWorkerStore( - ApplicationServiceWorkerStore, - ReceiptsWorkerStore, - PusherWorkerStore, - RoomMemberWorkerStore, - EventsWorkerStore, - SQLBaseStore, -): - """This is an abstract base class where subclasses must implement - `get_max_push_rules_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, database: Database, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) - - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) # type: Union[ChainedIdGenerator, SlavedIdTracker] - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) - - push_rules_prefill, push_rules_id = self.db.get_cache_dict( - db_conn, - "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self.get_max_push_rules_stream_id(), - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - push_rules_id, - prefilled_cache=push_rules_prefill, - ) - - @abc.abstractmethod - def get_max_push_rules_stream_id(self): - """Get the position of the push rules stream. - - Returns: - int - """ - raise NotImplementedError() - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self.db.simple_select_list( - table="push_rules", - keyvalues={"user_name": user_id}, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", - ), - desc="get_push_rules_enabled_for_user", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - - rules = _load_rules(rows, enabled_map) - - return rules - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self.db.simple_select_list( - table="push_rules_enable", - keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), - desc="get_push_rules_enabled_for_user", - ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? 
< stream_id" - ) - txn.execute(sql, (user_id, last_id)) - (count,) = txn.fetchone() - return bool(count) - - return self.db.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) - - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules(self, user_ids): - if not user_ids: - return {} - - results = {user_id: [] for user_id in user_ids} - - rows = yield self.db.simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=("*",), - desc="bulk_get_push_rules", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - for row in rows: - results.setdefault(row["user_name"], []).append(row) - - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) - - for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) - - return results - - @defer.inlineCallbacks - def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): - """Copy a single push rule from one room to another for a specific user. - - Args: - new_room_id (str): ID of the new room. - user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. - """ - # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) - new_rule_id = rule_id_scope + "/" + new_room_id - - # Change room id in each condition - for condition in rule.get("conditions", []): - if condition.get("key") == "room_id": - condition["pattern"] = new_room_id - - # Add the rule for the new room - yield self.add_push_rule( - user_id=user_id, - rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], - ) - - @defer.inlineCallbacks - def copy_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): - """Copy all of the push rules from one room to another for a specific - user. - - Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. - user_id (str): ID of user to copy push rules for. - """ - # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) - - # Get rules relating to the old room and copy them to the new room - for rule in user_push_rules: - conditions = rule.get("conditions", []) - if any( - (c.get("key") == "room_id" and c.get("pattern") == old_room_id) - for c in conditions - ): - yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. 
However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. - local_users_in_room = { - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - } - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules_enabled(self, user_ids): - if not user_ids: - return {} - - results = {user_id: {} for user_id in user_ids} - - rows = yield self.db.simple_select_many_batch( - table="push_rules_enable", - column="user_name", - iterable=user_ids, - retcols=("user_name", "rule_id", "enabled"), - desc="bulk_get_push_rules_enabled", - ) - for row in rows: - enabled = bool(row["enabled"]) - results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled - return results - - async def get_all_push_rule_updates( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for push_rules replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_push_rule_updates_txn(txn): - sql = """ - SELECT stream_id, user_id - FROM push_rules_stream - WHERE ? 
< stream_id AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] - - limited = False - upper_bound = current_id - if len(updates) == limit: - limited = True - upper_bound = updates[-1][0] - - return updates, upper_bound, limited - - return await self.db.runInteraction( - "get_all_push_rule_updates", get_all_push_rule_updates_txn - ) - - -class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( - self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, - ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - if before or after: - yield self.db.runInteraction( - "_add_push_rule_relative_txn", - self._add_push_rule_relative_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ) - else: - yield self.db.runInteraction( - "_add_push_rule_highest_priority_txn", - self._add_push_rule_highest_priority_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ) - - def _add_push_rule_relative_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - relative_to_rule = before or after - - res = self.db.simple_select_one_txn( - txn, - table="push_rules", - keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, - retcols=["priority_class", "priority"], - allow_none=True, - ) - - if not res: - raise RuleNotFoundException( - "before/after rule not found: %s" % (relative_to_rule,) - ) - - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] - - if base_priority_class != priority_class: - raise InconsistentRuleException( - "Given priority class does not match class of relative rule" - ) - - if before: - # Higher priority rules are executed first, So adding a rule before - # a rule means giving it a higher priority than that rule. - new_rule_priority = base_rule_priority + 1 - else: - # We increment the priority of the existing rules to make space for - # the new rule. Therefore if we want this rule to appear after - # an existing rule we give it the priority of the existing rule, - # and then increment the priority of the existing rule. - new_rule_priority = base_rule_priority - - sql = ( - "UPDATE push_rules SET priority = priority + 1" - " WHERE user_name = ? AND priority_class = ? AND priority >= ?" - ) - - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_rule_priority, - conditions_json, - actions_json, - ) - - def _add_push_rule_highest_priority_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. 
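
[Editorial note, not part of the patch] The relative-insert logic above works on plain integer priorities within a priority_class: higher priority fires first, inserting "before" a rule takes that rule's priority plus one, inserting "after" takes the rule's own priority, and every existing rule at or above the chosen slot is shifted up by one (the UPDATE ... SET priority = priority + 1). A toy in-memory sketch of that bookkeeping, illustrative only and not the store's code:

    def insert_relative(rules, new_rule_id, relative_to, before):
        # rules: list of (rule_id, priority) tuples; higher priority is evaluated first.
        base = dict(rules)[relative_to]
        new_priority = base + 1 if before else base
        # Shift every existing rule at or above the slot we want, mirroring
        # "UPDATE push_rules SET priority = priority + 1 WHERE ... priority >= ?".
        shifted = [
            (rid, prio + 1 if prio >= new_priority else prio) for rid, prio in rules
        ]
        shifted.append((new_rule_id, new_priority))
        return sorted(shifted, key=lambda r: -r[1])

    rules = [("a", 2), ("b", 1), ("c", 0)]
    print(insert_relative(rules, "new", relative_to="b", before=True))
    # [('a', 3), ('new', 2), ('b', 1), ('c', 0)] -> "new" now fires before "b"
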
- self.database_engine.lock_table(txn, "push_rules") - - # find the highest priority rule in that class - sql = ( - "SELECT COUNT(*), MAX(priority) FROM push_rules" - " WHERE user_name = ? and priority_class = ?" - ) - txn.execute(sql, (user_id, priority_class)) - res = txn.fetchall() - (how_many, highest_prio) = res[0] - - new_prio = 0 - if how_many > 0: - new_prio = highest_prio + 1 - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_prio, - conditions_json, - actions_json, - ) - - def _upsert_push_rule_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): - """Specialised version of simple_upsert_txn that picks a push_rule_id - using the _push_rule_id_gen if it needs to insert the rule. It assumes - that the "push_rules" table is locked""" - - sql = ( - "UPDATE push_rules" - " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" - " WHERE user_name = ? AND rule_id = ?" - ) - - txn.execute( - sql, - (priority_class, priority, conditions_json, actions_json, user_id, rule_id), - ) - - if txn.rowcount == 0: - # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next() - - self.db.simple_insert_txn( - txn, - table="push_rules", - values={ - "id": push_rule_id, - "user_name": user_id, - "rule_id": rule_id, - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - if update_stream: - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ADD", - data={ - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): - """ - Delete a push rule. 
Args specify the row to be deleted and can be - any of the columns in the push_rule table, but below are the - standard ones - - Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted - """ - - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self.db.simple_delete_one_txn( - txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} - ) - - self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db.runInteraction( - "delete_push_rule", - delete_push_rule_txn, - stream_id, - event_stream_ordering, - ) - - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db.runInteraction( - "_set_push_rule_enabled_txn", - self._set_push_rule_enabled_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - ) - - def _set_push_rule_enabled_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled - ): - new_id = self._push_rules_enable_id_gen.get_next() - self.db.simple_upsert_txn( - txn, - "push_rules_enable", - {"user_name": user_id, "rule_id": rule_id}, - {"enabled": 1 if enabled else 0}, - {"id": new_id}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ENABLE" if enabled else "DISABLE", - ) - - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) - - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): - if is_default_rule: - # Add a dummy rule to the rules table with the user specified - # actions. - priority_class = -1 - priority = 1 - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - "[]", - actions_json, - update_stream=False, - ) - else: - self.db.simple_update_one_txn( - txn, - "push_rules", - {"user_name": user_id, "rule_id": rule_id}, - {"actions": actions_json}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ACTIONS", - data={"actions": actions_json}, - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db.runInteraction( - "set_push_rule_actions", - set_push_rule_actions_txn, - stream_id, - event_stream_ordering, - ) - - def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): - values = { - "stream_id": stream_id, - "event_stream_ordering": event_stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": op, - } - if data is not None: - values.update(data) - - self.db.simple_insert_txn(txn, "push_rules_stream", values=values) - - txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) - txn.call_after( - self.push_rules_stream_cache.entity_has_changed, user_id, stream_id - ) - - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. 
- Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py deleted file mode 100644 index e18f1ca87c..0000000000 --- a/synapse/storage/data_stores/main/pusher.py +++ /dev/null @@ -1,354 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Iterable, Iterator, List, Tuple - -from canonicaljson import encode_canonical_json - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList - -logger = logging.getLogger(__name__) - - -class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: - """JSON-decode the data in the rows returned from the `pushers` table - - Drops any rows whose data cannot be decoded - """ - for r in rows: - dataJson = r["data"] - try: - r["data"] = db_to_json(dataJson) - except Exception as e: - logger.warning( - "Invalid JSON in data for pusher %d: %s, %s", - r["id"], - dataJson, - e.args[0], - ) - continue - - yield r - - @defer.inlineCallbacks - def user_has_pusher(self, user_id): - ret = yield self.db.simple_select_one_onecol( - "pushers", {"user_name": user_id}, "id", allow_none=True - ) - return ret is not None - - def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) - - def get_pushers_by_user_id(self, user_id): - return self.get_pushers_by({"user_name": user_id}) - - @defer.inlineCallbacks - def get_pushers_by(self, keyvalues): - ret = yield self.db.simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], - desc="get_pushers_by", - ) - return self._decode_pushers_rows(ret) - - @defer.inlineCallbacks - def get_all_pushers(self): - def get_pushers(txn): - txn.execute("SELECT * FROM pushers") - rows = self.db.cursor_to_dict(txn) - - return self._decode_pushers_rows(rows) - - rows = yield self.db.runInteraction("get_all_pushers", get_pushers) - return rows - - async def get_all_updated_pushers_rows( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for pushers replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. 
- current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_updated_pushers_rows_txn(txn): - sql = """ - SELECT id, user_name, app_id, pushkey - FROM pushers - WHERE ? < id AND id <= ? - ORDER BY id ASC LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (user_name, app_id, pushkey, False)) - for stream_id, user_name, app_id, pushkey in txn - ] - - sql = """ - SELECT stream_id, user_id, app_id, pushkey - FROM deleted_pushers - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - updates.extend( - (stream_id, (user_name, app_id, pushkey, True)) - for stream_id, user_name, app_id, pushkey in txn - ) - - updates.sort() # Sort so that they're ordered by stream id - - limited = False - upper_bound = current_id - if len(updates) >= limit: - limited = True - upper_bound = updates[-1][0] - - return updates, upper_bound, limited - - return await self.db.runInteraction( - "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn - ) - - @cachedInlineCallbacks(num_args=1, max_entries=15000) - def get_if_user_has_pusher(self, user_id): - # This only exists for the cachedList decorator - raise NotImplementedError() - - @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def get_if_users_have_pushers(self, user_ids): - rows = yield self.db.simple_select_many_batch( - table="pushers", - column="user_name", - iterable=user_ids, - retcols=["user_name"], - desc="get_if_users_have_pushers", - ) - - result = {user_id: False for user_id in user_ids} - result.update({r["user_name"]: True for r in rows}) - - return result - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self.db.simple_update_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - {"last_stream_ordering": last_stream_ordering}, - desc="update_pusher_last_stream_ordering", - ) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): - """Update the last stream ordering position we've processed up to for - the given pusher. - - Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) - - Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. 
- """ - updated = yield self.db.simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={ - "last_stream_ordering": last_stream_ordering, - "last_success": last_success, - }, - desc="update_pusher_last_stream_ordering_and_success", - ) - - return bool(updated) - - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db.simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={"failing_since": failing_since}, - desc="update_pusher_failing_since", - ) - - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self.db.simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", - ) - - params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } - - return params_by_room - - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so simple_upsert will retry - yield self.db.simple_upsert( - "pusher_throttle", - {"pusher": pusher_id, "room_id": room_id}, - params, - desc="set_throttle_params", - lock=False, - ) - - -class PusherStore(PusherWorkerStore): - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - - @defer.inlineCallbacks - def add_pusher( - self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - lang, - data, - last_stream_ordering, - profile_tag="", - ): - with self._pushers_id_gen.get_next() as stream_id: - # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.db.simple_upsert( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - values={ - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": bytearray(encode_canonical_json(data)), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - desc="add_pusher", - lock=False, - ) - - user_has_pusher = self.get_if_user_has_pusher.cache.get( - (user_id,), None, update_metrics=False - ) - - if user_has_pusher is not True: - # invalidate, since we the user might not have had a pusher before - yield self.db.runInteraction( - "add_pusher", - self._invalidate_cache_and_stream, - self.get_if_user_has_pusher, - (user_id,), - ) - - @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): - def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( - txn, self.get_if_user_has_pusher, (user_id,) - ) - - self.db.simple_delete_one_txn( - txn, - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - ) - - # it's possible for us to end up with duplicate rows for - # (app_id, pushkey, user_id) at different stream_ids, but that - # doesn't really matter. 
- self.db.simple_insert_txn( - txn, - table="deleted_pushers", - values={ - "stream_id": stream_id, - "app_id": app_id, - "pushkey": pushkey, - "user_id": user_id, - }, - ) - - with self._pushers_id_gen.get_next() as stream_id: - yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py deleted file mode 100644 index 1d723f2d34..0000000000 --- a/synapse/storage/data_stores/main/receipts.py +++ /dev/null @@ -1,589 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import logging -from typing import List, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database -from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -class ReceiptsWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_receipt_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. 
- __metaclass__ = abc.ABCMeta - - def __init__(self, database: Database, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) - - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() - ) - - @abc.abstractmethod - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ - raise NotImplementedError() - - @cachedInlineCallbacks() - def get_users_with_read_receipts_in_room(self, room_id): - receipts = yield self.get_receipts_for_room(room_id, "m.read") - return {r["user_id"] for r in receipts} - - @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self.db.simple_select_list( - table="receipts_linearized", - keyvalues={"room_id": room_id, "receipt_type": receipt_type}, - retcols=("user_id", "event_id"), - desc="get_receipts_for_room", - ) - - @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db.simple_select_one_onecol( - table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - retcol="event_id", - desc="get_own_receipt_for_user", - allow_none=True, - ) - - @cachedInlineCallbacks(num_args=2) - def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.db.simple_select_list( - table="receipts_linearized", - keyvalues={"user_id": user_id, "receipt_type": receipt_type}, - retcols=("room_id", "event_id"), - desc="get_receipts_for_user", - ) - - return {row["room_id"]: row["event_id"] for row in rows} - - @defer.inlineCallbacks - def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): - sql = ( - "SELECT rl.room_id, rl.event_id," - " e.topological_ordering, e.stream_ordering" - " FROM receipts_linearized AS rl" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE rl.room_id = e.room_id" - " AND rl.event_id = e.event_id" - " AND user_id = ?" - ) - txn.execute(sql, (user_id,)) - return txn.fetchall() - - rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) - return { - row[0]: { - "event_id": row[1], - "topological_ordering": row[2], - "stream_ordering": row[3], - } - for row in rows - } - - @defer.inlineCallbacks - def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): - """Get receipts for multiple rooms for sending to clients. - - Args: - room_ids (list): List of room_ids. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches - from the start. - - Returns: - list: A list of receipts. - """ - room_ids = set(room_ids) - - if from_key is not None: - # Only ask the database about rooms where there have been new - # receipts added since `from_key` - room_ids = yield self._receipts_stream_cache.get_entities_changed( - room_ids, from_key - ) - - results = yield self._get_linearized_receipts_for_rooms( - room_ids, to_key, from_key=from_key - ) - - return [ev for res in results.values() for ev in res] - - def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): - """Get receipts for a single room for sending to clients. - - Args: - room_ids (str): The room id. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches - from the start. - - Returns: - Deferred[list]: A list of receipts. 
- """ - if from_key is not None: - # Check the cache first to see if any new receipts have been added - # since`from_key`. If not we can no-op. - if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): - defer.succeed([]) - - return self._get_linearized_receipts_for_room(room_id, to_key, from_key) - - @cachedInlineCallbacks(num_args=3, tree=True) - def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): - """See get_linearized_receipts_for_room - """ - - def f(txn): - if from_key: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id = ? AND stream_id > ? AND stream_id <= ?" - ) - - txn.execute(sql, (room_id, from_key, to_key)) - else: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id = ? AND stream_id <= ?" - ) - - txn.execute(sql, (room_id, to_key)) - - rows = self.db.cursor_to_dict(txn) - - return rows - - rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) - - if not rows: - return [] - - content = {} - for row in rows: - content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ - row["user_id"] - ] = db_to_json(row["data"]) - - return [{"type": "m.receipt", "room_id": room_id, "content": content}] - - @cachedList( - cached_method_name="_get_linearized_receipts_for_room", - list_name="room_ids", - num_args=3, - inlineCallbacks=True, - ) - def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): - if not room_ids: - return {} - - def f(txn): - if from_key: - sql = """ - SELECT * FROM receipts_linearized WHERE - stream_id > ? AND stream_id <= ? AND - """ - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - - txn.execute(sql + clause, [from_key, to_key] + list(args)) - else: - sql = """ - SELECT * FROM receipts_linearized WHERE - stream_id <= ? AND - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - - txn.execute(sql + clause, [to_key] + list(args)) - - return self.db.cursor_to_dict(txn) - - txn_results = yield self.db.runInteraction( - "_get_linearized_receipts_for_rooms", f - ) - - results = {} - for row in txn_results: - # We want a single event per room, since we want to batch the - # receipts by room, event and type. - room_event = results.setdefault( - row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, - ) - - # The content is of the form: - # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } - event_entry = room_event["content"].setdefault(row["event_id"], {}) - receipt_type = event_entry.setdefault(row["receipt_type"], {}) - - receipt_type[row["user_id"]] = db_to_json(row["data"]) - - results = { - room_id: [results[room_id]] if room_id in results else [] - for room_id in room_ids - } - return results - - def get_users_sent_receipts_between(self, last_id: int, current_id: int): - """Get all users who sent receipts between `last_id` exclusive and - `current_id` inclusive. - - Returns: - Deferred[List[str]] - """ - - if last_id == current_id: - return defer.succeed([]) - - def _get_users_sent_receipts_between_txn(txn): - sql = """ - SELECT DISTINCT user_id FROM receipts_linearized - WHERE ? < stream_id AND stream_id <= ? 
- """ - txn.execute(sql, (last_id, current_id)) - - return [r[0] for r in txn] - - return self.db.runInteraction( - "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn - ) - - async def get_all_updated_receipts( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: - """Get updates for receipts replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_updated_receipts_txn(txn): - sql = """ - SELECT stream_id, room_id, receipt_type, user_id, event_id, data - FROM receipts_linearized - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - - updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] - - limited = False - upper_bound = current_id - - if len(updates) == limit: - limited = True - upper_bound = updates[-1][0] - - return updates, upper_bound, limited - - return await self.db.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) - - def _invalidate_get_users_with_receipts_in_room( - self, room_id, receipt_type, user_id - ): - if receipt_type != "m.read": - return - - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False - ) - - # first handle the ObservableDeferred case - if isinstance(res, ObservableDeferred): - if res.has_called(): - res = res.get_result() - else: - res = None - - if res and user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - - -class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: Database, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - - super(ReceiptsStore, self).__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - - def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): - """Inserts a read-receipt into the database if it's newer than the current RR - - Returns: int|None - None if the RR is older than the current RR - otherwise, the rx timestamp of the event that the RR corresponds to - (or 0 if the event is unknown) - """ - res = self.db.simple_select_one_txn( - txn, - table="events", - retcols=["stream_ordering", "received_ts"], - keyvalues={"event_id": event_id}, - allow_none=True, - ) - - stream_ordering = int(res["stream_ordering"]) if res else None - rx_ts = 
res["received_ts"] if res else 0 - - # We don't want to clobber receipts for more recent events, so we - # have to compare orderings of existing receipts - if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized as r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" - ) - txn.execute(sql, (room_id, receipt_type, user_id)) - - for so, eid in txn: - if int(so) >= stream_ordering: - logger.debug( - "Ignoring new receipt for %s in favour of existing " - "one for later event %s", - event_id, - eid, - ) - return None - - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) - - txn.call_after( - self._receipts_stream_cache.entity_has_changed, room_id, stream_id - ) - - txn.call_after( - self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type), - ) - - self.db.simple_upsert_txn( - txn, - table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - values={ - "stream_id": stream_id, - "event_id": event_id, - "data": json.dumps(data), - }, - # receipts_linearized has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, - ) - - if receipt_type == "m.read" and stream_ordering is not None: - self._remove_old_push_actions_before_txn( - txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering - ) - - return rx_ts - - @defer.inlineCallbacks - def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): - """Insert a receipt, either from local client or remote server. - - Automatically does conversion between linearized and graph - representations. - """ - if not event_ids: - return - - if len(event_ids) == 1: - linearized_event_id = event_ids[0] - else: - # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn): - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? 
AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - - linearized_event_id = yield self.db.runInteraction( - "insert_receipt_conv", graph_to_linear - ) - - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: - event_ts = yield self.db.runInteraction( - "insert_linearized_receipt", - self.insert_linearized_receipt_txn, - room_id, - receipt_type, - user_id, - linearized_event_id, - data, - stream_id=stream_id, - ) - - if event_ts is None: - return None - - now = self._clock.time_msec() - logger.debug( - "RR for event %s in %s (%i ms old)", - linearized_event_id, - room_id, - now - event_ts, - ) - - yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db.runInteraction( - "insert_graph_receipt", - self.insert_graph_receipt_txn, - room_id, - receipt_type, - user_id, - event_ids, - data, - ) - - def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) - - self.db.simple_delete_txn( - txn, - table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - ) - self.db.simple_insert_txn( - txn, - table="receipts_graph", - values={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), - }, - ) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py deleted file mode 100644 index 27d2c5028c..0000000000 --- a/synapse/storage/data_stores/main/registration.py +++ /dev/null @@ -1,1582 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
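
The ordering check in insert_linearized_receipt_txn above boils down to: a new read receipt is only persisted if every receipt the user already has in the room points at an event with a strictly lower stream ordering. A minimal plain-Python sketch of that rule follows; the function name and inputs are illustrative, not part of the code above.

    def should_persist_receipt(new_stream_ordering, existing_orderings):
        # Mirrors the `int(so) >= stream_ordering` check above: if any
        # existing receipt already covers an equal or later event, the
        # incoming receipt is ignored rather than clobbering it.
        return all(so < new_stream_ordering for so in existing_orderings)

    assert should_persist_receipt(10, [3, 7])      # strictly newer: persist
    assert not should_persist_receipt(5, [5, 2])   # equal/later receipt exists: drop
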
- -import logging -import re -from typing import Optional - -from twisted.internet import defer -from twisted.internet.defer import Deferred - -from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database -from synapse.storage.types import Cursor -from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks - -THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 - -logger = logging.getLogger(__name__) - - -class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) - - self.config = hs.config - self.clock = hs.get_clock() - - self._user_id_seq = build_sequence_generator( - database.engine, find_max_generated_user_id_localpart, "user_id_seq", - ) - - @cached() - def get_user_by_id(self, user_id): - return self.db.simple_select_one( - table="users", - keyvalues={"name": user_id}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "consent_version", - "consent_server_notice_sent", - "appservice_id", - "creation_ts", - "user_type", - "deactivated", - ], - allow_none=True, - desc="get_user_by_id", - ) - - @defer.inlineCallbacks - def is_trial_user(self, user_id): - """Checks if user is in the "trial" period, i.e. within the first - N days of registration defined by `mau_trial_days` config - - Args: - user_id (str) - - Returns: - Deferred[bool] - """ - - info = yield self.get_user_by_id(user_id) - if not info: - return False - - now = self.clock.time_msec() - trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 - is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms - return is_trial - - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. - """ - return self.db.runInteraction( - "get_user_by_access_token", self._query_for_auth, token - ) - - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): - """Get the expiration timestamp for the account bearing a given user ID. - - Args: - user_id (str): The ID of the user. - Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). - """ - res = yield self.db.simple_select_one_onecol( - table="account_validity", - keyvalues={"user_id": user_id}, - retcol="expiration_ts_ms", - allow_none=True, - desc="get_expiration_ts_for_user", - ) - return res - - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): - """Updates the account validity properties of the given account, with the - given values. - - Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds - since epoch. 
- email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity - period. - renewal_token (str): Renewal token the user can use to extend the validity - of their account. Defaults to no token. - """ - - def set_account_validity_for_user_txn(txn): - self.db.simple_update_txn( - txn=txn, - table="account_validity", - keyvalues={"user_id": user_id}, - updatevalues={ - "expiration_ts_ms": expiration_ts, - "email_sent": email_sent, - "renewal_token": renewal_token, - }, - ) - self._invalidate_cache_and_stream( - txn, self.get_expiration_ts_for_user, (user_id,) - ) - - yield self.db.runInteraction( - "set_account_validity_for_user", set_account_validity_for_user_txn - ) - - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): - """Defines a renewal token for a given user. - - Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the - user's account. - - Raises: - StoreError: The provided token is already set for another user. - """ - yield self.db.simple_update_one( - table="account_validity", - keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, - desc="set_renewal_token_for_user", - ) - - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): - """Get a user ID from a renewal token. - - Args: - renewal_token (str): The renewal token to perform the lookup with. - - Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. - """ - res = yield self.db.simple_select_one_onecol( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcol="user_id", - desc="get_user_from_renewal_token", - ) - - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): - """Get the renewal token associated with a given user ID. - - Args: - user_id (str): The user ID to lookup a token for. - - Returns: - defer.Deferred[str]: The renewal token associated with this user ID. - """ - res = yield self.db.simple_select_one_onecol( - table="account_validity", - keyvalues={"user_id": user_id}, - retcol="renewal_token", - desc="get_renewal_token_for_user", - ) - - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): - """Selects users whose account will expire in the [now, now + renew_at] time - window (see configuration for account_validity for information on what renew_at - refers to). - - Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] - """ - - def select_users_txn(txn, now_ms, renew_at): - sql = ( - "SELECT user_id, expiration_ts_ms FROM account_validity" - " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" - ) - values = [False, now_ms, renew_at] - txn.execute(sql, values) - return self.db.cursor_to_dict(txn) - - res = yield self.db.runInteraction( - "get_users_expiring_soon", - select_users_txn, - self.clock.time_msec(), - self.config.account_validity.renew_at, - ) - - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): - """Sets or unsets the flag that indicates whether a renewal email has been sent - to the user (and the user hasn't renewed their account yet). - - Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent - to this user. 
- """ - yield self.db.simple_update_one( - table="account_validity", - keyvalues={"user_id": user_id}, - updatevalues={"email_sent": email_sent}, - desc="set_renewal_mail_status", - ) - - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): - """Deletes the entry for the given user in the account validity table, removing - their expiration date and renewal token. - - Args: - user_id (str): ID of the user to remove from the account validity table. - """ - yield self.db.simple_delete_one( - table="account_validity", - keyvalues={"user_id": user_id}, - desc="delete_account_validity_for_user", - ) - - async def is_server_admin(self, user): - """Determines if a user is an admin of this homeserver. - - Args: - user (UserID): user ID of the user to test - - Returns (bool): - true iff the user is a server admin, false otherwise. - """ - res = await self.db.simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - return bool(res) if res else False - - def set_server_admin(self, user, admin): - """Sets whether a user is an admin of this homeserver. - - Args: - user (UserID): user ID of the user to test - admin (bool): true iff the user is to be a server admin, - false otherwise. - """ - - def set_server_admin_txn(txn): - self.db.simple_update_one_txn( - txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user.to_string(),) - ) - - return self.db.runInteraction("set_server_admin", set_server_admin_txn) - - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.db.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - - @cachedInlineCallbacks() - def is_real_user(self, user_id): - """Determines if the user is a real user, ie does not have a 'user_type'. - - Args: - user_id (str): user id to test - - Returns: - Deferred[bool]: True if user 'user_type' is null or empty string - """ - res = yield self.db.runInteraction( - "is_real_user", self.is_real_user_txn, user_id - ) - return res - - @cached() - def is_support_user(self, user_id): - """Determines if the user is of type UserTypes.SUPPORT - - Args: - user_id (str): user id to test - - Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT - """ - return self.db.runInteraction( - "is_support_user", self.is_support_user_txn, user_id - ) - - def is_real_user_txn(self, txn, user_id): - res = self.db.simple_select_one_onecol_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - retcol="user_type", - allow_none=True, - ) - return res is None - - def is_support_user_txn(self, txn, user_id): - res = self.db.simple_select_one_onecol_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - retcol="user_type", - allow_none=True, - ) - return True if res == UserTypes.SUPPORT else False - - def get_users_by_id_case_insensitive(self, user_id): - """Gets users that match user_id case insensitively. - Returns a mapping of user_id -> password_hash. 
- """ - - def f(txn): - sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" - txn.execute(sql, (user_id,)) - return dict(txn) - - return self.db.runInteraction("get_users_by_id_case_insensitive", f) - - async def get_user_by_external_id( - self, auth_provider: str, external_id: str - ) -> str: - """Look up a user by their external auth id - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - - Returns: - str|None: the mxid of the user, or None if they are not known - """ - return await self.db.simple_select_one_onecol( - table="user_external_ids", - keyvalues={"auth_provider": auth_provider, "external_id": external_id}, - retcol="user_id", - allow_none=True, - desc="get_user_by_external_id", - ) - - @defer.inlineCallbacks - def count_all_users(self): - """Counts all users registered on the homeserver.""" - - def _count_users(txn): - txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.db.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 - - ret = yield self.db.runInteraction("count_users", _count_users) - return ret - - def count_daily_user_type(self): - """ - Counts 1) native non guest users - 2) native guests users - 3) bridged users - who registered on the homeserver in the past 24 hours - """ - - def _count_daily_user_type(txn): - yesterday = int(self._clock.time()) - (60 * 60 * 24) - - sql = """ - SELECT user_type, COALESCE(count(*), 0) AS count FROM ( - SELECT - CASE - WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' - WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' - WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' - END AS user_type - FROM users - WHERE creation_ts > ? - ) AS t GROUP BY user_type - """ - results = {"native": 0, "guest": 0, "bridged": 0} - txn.execute(sql, (yesterday,)) - for row in txn: - results[row[0]] = row[1] - return results - - return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) - - @defer.inlineCallbacks - def count_nonbridged_users(self): - def _count_users(txn): - txn.execute( - """ - SELECT COALESCE(COUNT(*), 0) FROM users - WHERE appservice_id IS NULL - """ - ) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_users", _count_users) - return ret - - @defer.inlineCallbacks - def count_real_users(self): - """Counts all users without a special user_type registered on the homeserver.""" - - def _count_users(txn): - txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.db.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 - - ret = yield self.db.runInteraction("count_real_users", _count_users) - return ret - - async def generate_user_id(self) -> str: - """Generate a suitable localpart for a guest user - - Returns: a (hopefully) free localpart - """ - next_id = await self.db.runInteraction( - "generate_user_id", self._user_id_seq.get_next_id_txn - ) - - return str(next_id) - - async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: - """Returns user id from threepid - - Args: - medium: threepid medium e.g. email - address: threepid address e.g. 
me@example.com - - Returns: - The user ID or None if no user id/threepid mapping exists - """ - user_id = await self.db.runInteraction( - "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address - ) - return user_id - - def get_user_id_by_threepid_txn(self, txn, medium, address): - """Returns user id from threepid - - Args: - txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com - - Returns: - str|None: user id or None if no user id/threepid mapping exists - """ - ret = self.db.simple_select_one_txn( - txn, - "user_threepids", - {"medium": medium, "address": address}, - ["user_id"], - True, - ) - if ret: - return ret["user_id"] - return None - - @defer.inlineCallbacks - def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.db.simple_upsert( - "user_threepids", - {"medium": medium, "address": address}, - {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, - ) - - @defer.inlineCallbacks - def user_get_threepids(self, user_id): - ret = yield self.db.simple_select_list( - "user_threepids", - {"user_id": user_id}, - ["medium", "address", "validated_at", "added_at"], - "user_get_threepids", - ) - return ret - - def user_delete_threepid(self, user_id, medium, address): - return self.db.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id, "medium": medium, "address": address}, - desc="user_delete_threepid", - ) - - def user_delete_threepids(self, user_id: str): - """Delete all threepid this user has bound - - Args: - user_id: The user id to delete all threepids of - - """ - return self.db.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id}, - desc="user_delete_threepids", - ) - - def add_user_bound_threepid(self, user_id, medium, address, id_server): - """The server proxied a bind request to the given identity server on - behalf of the given user. We need to remember this in case the user - asks us to unbind the threepid. - - Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred - """ - # We need to use an upsert, in case they user had already bound the - # threepid - return self.db.simple_upsert( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - values={}, - insertion_values={}, - desc="add_user_bound_threepid", - ) - - def user_get_bound_threepids(self, user_id): - """Get the threepids that a user has bound to an identity server through the homeserver - The homeserver remembers where binds to an identity server occurred. Using this - method can retrieve those threepids. - - Args: - user_id (str): The ID of the user to retrieve threepids for - - Returns: - Deferred[list[dict]]: List of dictionaries containing the following: - medium (str): The medium of the threepid (e.g "email") - address (str): The address of the threepid (e.g "bob@example.com") - """ - return self.db.simple_select_list( - table="user_threepid_id_server", - keyvalues={"user_id": user_id}, - retcols=["medium", "address"], - desc="user_get_bound_threepids", - ) - - def remove_user_bound_threepid(self, user_id, medium, address, id_server): - """The server proxied an unbind request to the given identity server on - behalf of the given user, so we remove the mapping of threepid to - identity server. 
- - Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred - """ - return self.db.simple_delete( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - desc="remove_user_bound_threepid", - ) - - def get_id_servers_user_bound(self, user_id, medium, address): - """Get the list of identity servers that the server proxied bind - requests to for given user and threepid - - Args: - user_id (str) - medium (str) - address (str) - - Returns: - Deferred[list[str]]: Resolves to a list of identity servers - """ - return self.db.simple_select_onecol( - table="user_threepid_id_server", - keyvalues={"user_id": user_id, "medium": medium, "address": address}, - retcol="id_server", - desc="get_id_servers_user_bound", - ) - - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): - """Retrieve the value for the `deactivated` property for the provided user. - - Args: - user_id (str): The ID of the user to retrieve the status for. - - Returns: - defer.Deferred(bool): The requested value. - """ - - res = yield self.db.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="deactivated", - desc="get_user_deactivated_status", - ) - - # Convert the integer into a boolean. - return res == 1 - - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): - """Gets a session_id and last_send_attempt (if available) for a - combination of validation metadata - - Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str): A unique string provided by the client to help identify this - validation attempt - validated (bool|None): Whether sessions should be filtered by - whether they have been validated already or not. None to - perform no filtering - - Returns: - Deferred[dict|None]: A dict containing the following: - * address - address of the 3pid - * medium - medium of the 3pid - * client_secret - a secret provided by the client for this validation session - * session_id - ID of the validation session - * send_attempt - a number serving to dedupe send attempts for this session - * validated_at - timestamp of when this session was validated if so - - Otherwise None if a validation session is not found - """ - if not client_secret: - raise SynapseError( - 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM - ) - - keyvalues = {"client_secret": client_secret} - if medium: - keyvalues["medium"] = medium - if address: - keyvalues["address"] = address - if sid: - keyvalues["session_id"] = sid - - assert address or sid - - def get_threepid_validation_session_txn(txn): - sql = """ - SELECT address, session_id, medium, client_secret, - last_send_attempt, validated_at - FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" % k for k in keyvalues.keys()), - ) - - if validated is not None: - sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") - - sql += " LIMIT 1" - - txn.execute(sql, list(keyvalues.values())) - rows = self.db.cursor_to_dict(txn) - if not rows: - return None - - return rows[0] - - return self.db.runInteraction( - "get_threepid_validation_session", get_threepid_validation_session_txn - ) - - def delete_threepid_session(self, session_id): - """Removes a threepid validation session from the database. 
This can - be done after validation has been performed and whatever action was - waiting on it has been carried out - - Args: - session_id (str): The ID of the session to delete - """ - - def delete_threepid_session_txn(txn): - self.db.simple_delete_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id}, - ) - self.db.simple_delete_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - ) - - return self.db.runInteraction( - "delete_threepid_session", delete_threepid_session_txn - ) - - -class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: Database, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.clock = hs.get_clock() - self.config = hs.config - - self.db.updates.register_background_index_update( - "access_tokens_device_index", - index_name="access_tokens_device_id", - table="access_tokens", - columns=["user_id", "device_id"], - ) - - self.db.updates.register_background_index_update( - "users_creation_ts", - index_name="users_creation_ts", - table="users", - columns=["creation_ts"], - ) - - # we no longer use refresh tokens, but it's possible that some people - # might have a background update queued to build this index. Just - # clear the background update. - self.db.updates.register_noop_background_update("refresh_tokens_device_index") - - self.db.updates.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather - ) - - self.db.updates.register_background_update_handler( - "users_set_deactivated_flag", self._background_update_set_deactivated_flag - ) - - @defer.inlineCallbacks - def _background_update_set_deactivated_flag(self, progress, batch_size): - """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 - for each of them. - """ - - last_user = progress.get("user_id", "") - - def _background_update_set_deactivated_flag_txn(txn): - txn.execute( - """ - SELECT - users.name, - COUNT(access_tokens.token) AS count_tokens, - COUNT(user_threepids.address) AS count_threepids - FROM users - LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) - LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) - WHERE (users.password_hash IS NULL OR users.password_hash = '') - AND (users.appservice_id IS NULL OR users.appservice_id = '') - AND users.is_guest = 0 - AND users.name > ? 
- GROUP BY users.name - ORDER BY users.name ASC - LIMIT ?; - """, - (last_user, batch_size), - ) - - rows = self.db.cursor_to_dict(txn) - - if not rows: - return True, 0 - - rows_processed_nb = 0 - - for user in rows: - if not user["count_tokens"] and not user["count_threepids"]: - self.set_user_deactivated_status_txn(txn, user["name"], True) - rows_processed_nb += 1 - - logger.info("Marked %d rows as deactivated", rows_processed_nb) - - self.db.updates._background_update_progress_txn( - txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} - ) - - if batch_size > len(rows): - return True, len(rows) - else: - return False, len(rows) - - end, nb_processed = yield self.db.runInteraction( - "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn - ) - - if end: - yield self.db.updates._end_background_update("users_set_deactivated_flag") - - return nb_processed - - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.executemany(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - yield self.db.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - yield self.db.updates._end_background_update("user_threepids_grandfather") - - return 1 - - -class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): - super(RegistrationStore, self).__init__(database, db_conn, hs) - - self._account_validity = hs.config.account_validity - - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - - @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): - """Adds an access token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new access token to add. - device_id (str): ID of the device to associate with the access - token - valid_until_ms (int|None): when the token is valid until. None for - no expiry. - Raises: - StoreError if there was a problem adding this. 
- """ - next_id = self._access_tokens_id_gen.get_next() - - yield self.db.simple_insert( - "access_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - "valid_until_ms": valid_until_ms, - }, - desc="add_access_token_to_user", - ) - - def register_user( - self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - ): - """Attempts to register an account. - - Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str): The ID of the appservice registering the user. - create_profile_with_displayname (unicode): Optionally create a profile for - the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. - - Raises: - StoreError if the user_id could not be registered. - - Returns: - Deferred - """ - return self.db.runInteraction( - "register_user", - self._register_user, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - ) - - def _register_user( - self, - txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - ): - user_id_obj = UserID.from_string(user_id) - - now = int(self.clock.time()) - - try: - if was_guest: - # Ensure that the guest user actually exists - # ``allow_none=False`` makes this raise an exception - # if the row isn't in the database. - self.db.simple_select_one_txn( - txn, - "users", - keyvalues={"name": user_id, "is_guest": 1}, - retcols=("name",), - allow_none=False, - ) - - self.db.simple_update_one_txn( - txn, - "users", - keyvalues={"name": user_id, "is_guest": 1}, - updatevalues={ - "password_hash": password_hash, - "upgrade_ts": now, - "is_guest": 1 if make_guest else 0, - "appservice_id": appservice_id, - "admin": 1 if admin else 0, - "user_type": user_type, - }, - ) - else: - self.db.simple_insert_txn( - txn, - "users", - values={ - "name": user_id, - "password_hash": password_hash, - "creation_ts": now, - "is_guest": 1 if make_guest else 0, - "appservice_id": appservice_id, - "admin": 1 if admin else 0, - "user_type": user_type, - }, - ) - - except self.database_engine.module.IntegrityError: - raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - - if self._account_validity.enabled: - self.set_expiration_date_for_user_txn(txn, user_id) - - if create_profile_with_displayname: - # set a default displayname serverside to avoid ugly race - # between auto-joins and clients trying to set displaynames - # - # *obviously* the 'profiles' table uses localpart for user_id - # while everything else uses the full mxid. 
- txn.execute( - "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", - (user_id_obj.localpart, create_profile_with_displayname), - ) - - if self.hs.config.stats_enabled: - # we create a new completed user statistics row - - # we don't strictly need current_token since this user really can't - # have any state deltas before now (as it is a new user), but still, - # we include it for completeness. - current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) - self._update_stats_delta_txn( - txn, now, "user", user_id, {}, complete_with_stream_id=current_token - ) - - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) - - def record_user_external_id( - self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: - """Record a mapping from an external user id to a mxid - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - user_id: complete mxid that it is mapped to - """ - return self.db.simple_insert( - table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, - desc="record_user_external_id", - ) - - def user_set_password_hash(self, user_id, password_hash): - """ - NB. This does *not* evict any cache because the one use for this - removes most of the entries subsequently anyway so it would be - pointless. Use flush_user separately. - """ - - def user_set_password_hash_txn(txn): - self.db.simple_update_one_txn( - txn, "users", {"name": user_id}, {"password_hash": password_hash} - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - return self.db.runInteraction( - "user_set_password_hash", user_set_password_hash_txn - ) - - def user_set_consent_version(self, user_id, consent_version): - """Updates the user table to record privacy policy consent - - Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy the user has consented - to - - Raises: - StoreError(404) if user not found - """ - - def f(txn): - self.db.simple_update_one_txn( - txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"consent_version": consent_version}, - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - return self.db.runInteraction("user_set_consent_version", f) - - def user_set_consent_server_notice_sent(self, user_id, consent_version): - """Updates the user table to record that we have sent the user a server - notice about privacy policy consent - - Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy we have notified the - user about - - Raises: - StoreError(404) if user not found - """ - - def f(txn): - self.db.simple_update_one_txn( - txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"consent_server_notice_sent": consent_version}, - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - return self.db.runInteraction("user_set_consent_server_notice_sent", f) - - def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): - """ - Invalidate access tokens belonging to a user - - Args: - user_id (str): ID of user the tokens belong to - except_token_id (str): list of access_tokens IDs which should - *not* be deleted - device_id (str|None): ID of device the tokens are associated with. 
- If None, tokens associated with any device (or no device) will - be deleted - Returns: - defer.Deferred[list[str, int, str|None, int]]: a list of - (token, token id, device id) for each of the deleted tokens - """ - - def f(txn): - keyvalues = {"user_id": user_id} - if device_id is not None: - keyvalues["device_id"] = device_id - - items = keyvalues.items() - where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] - if except_token_id: - where_clause += " AND id != ?" - values.append(except_token_id) - - txn.execute( - "SELECT token, id, device_id FROM access_tokens WHERE %s" - % where_clause, - values, - ) - tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - - for token, _, _ in tokens_and_devices: - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (token,) - ) - - txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) - - return tokens_and_devices - - return self.db.runInteraction("user_delete_access_tokens", f) - - def delete_access_token(self, access_token): - def f(txn): - self.db.simple_delete_one_txn( - txn, table="access_tokens", keyvalues={"token": access_token} - ) - - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (access_token,) - ) - - return self.db.runInteraction("delete_access_token", f) - - @cachedInlineCallbacks() - def is_guest(self, user_id): - res = yield self.db.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - - def add_user_pending_deactivation(self, user_id): - """ - Adds a user to the table of users who need to be parted from all the rooms they're - in - """ - return self.db.simple_insert( - "users_pending_deactivation", - values={"user_id": user_id}, - desc="add_user_pending_deactivation", - ) - - def del_user_pending_deactivation(self, user_id): - """ - Removes the given user to the table of users who need to be parted from all the - rooms they're in, effectively marking that user as fully deactivated. - """ - # XXX: This should be simple_delete_one but we failed to put a unique index on - # the table, so somehow duplicate entries have ended up in it. - return self.db.simple_delete( - "users_pending_deactivation", - keyvalues={"user_id": user_id}, - desc="del_user_pending_deactivation", - ) - - def get_user_pending_deactivation(self): - """ - Gets one user from the table of users waiting to be parted from all the rooms - they're in. - """ - return self.db.simple_select_one_onecol( - "users_pending_deactivation", - keyvalues={}, - retcol="user_id", - allow_none=True, - desc="get_users_pending_deactivation", - ) - - def validate_threepid_session(self, session_id, client_secret, token, current_ts): - """Attempt to validate a threepid session using a token - - Args: - session_id (str): The id of a validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - token (str): A validation token - current_ts (int): The current unix time in milliseconds. Used for - checking token expiry status - - Raises: - ThreepidValidationError: if a matching validation token was not found or has - expired - - Returns: - deferred str|None: A str representing a link to redirect the user - to if there is one. 
- """ - - # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn): - row = self.db.simple_select_one_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - retcols=["client_secret", "validated_at"], - allow_none=True, - ) - - if not row: - raise ThreepidValidationError(400, "Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] - - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - - row = self.db.simple_select_one_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id, "token": token}, - retcols=["expires", "next_link"], - allow_none=True, - ) - - if not row: - raise ThreepidValidationError( - 400, "Validation token not found or has expired" - ) - expires = row["expires"] - next_link = row["next_link"] - - # If the session is already validated, no need to revalidate - if validated_at: - return next_link - - if expires <= current_ts: - raise ThreepidValidationError( - 400, "This token has expired. Please request a new one" - ) - - # Looks good. Validate the session - self.db.simple_update_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, - ) - - return next_link - - # Return next_link if it exists - return self.db.runInteraction( - "validate_threepid_session_txn", validate_threepid_session_txn - ) - - def upsert_threepid_validation_session( - self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self.db.simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - ) - - def start_or_continue_validation_session( - self, - medium, - address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ): - """Creates a new threepid validation session if it does not already - exist and associates a new validation token with it - - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - session_id (str): The id of this validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - next_link (str|None): The link to redirect the user to upon - successful validation - token (str): The validation token - token_expires (int): The timestamp for which after the token - will no longer be valid - """ - - def start_or_continue_validation_session_txn(txn): - # Create or update a validation session - 
self.db.simple_upsert_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values={ - "medium": medium, - "address": address, - "client_secret": client_secret, - }, - ) - - # Create a new validation token with this session ID - self.db.simple_insert_txn( - txn, - table="threepid_validation_token", - values={ - "session_id": session_id, - "token": token, - "next_link": next_link, - "expires": token_expires, - }, - ) - - return self.db.runInteraction( - "start_or_continue_validation_session", - start_or_continue_validation_session_txn, - ) - - def cull_expired_threepid_validation_tokens(self): - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - return txn.execute(sql, (ts,)) - - return self.db.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. - """ - - yield self.db.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.db.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - yield self.db.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. 
- """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self.db.simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - - -def find_max_generated_user_id_localpart(cur: Cursor) -> int: - """ - Gets the localpart of the max current generated user ID. - - Generated user IDs are integers, so we find the largest integer user ID - already taken and return that. - """ - - # We bound between '@0' and '@a' to avoid pulling the entire table - # out. - cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - max_found = 0 - - for (user_id,) in cur: - match = regex.search(user_id) - if match: - max_found = max(int(match.group(1)), max_found) - return max_found diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py deleted file mode 100644 index 27e5a2084a..0000000000 --- a/synapse/storage/data_stores/main/rejections.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db.simple_select_one_onecol( - table="rejections", - retcol="reason", - keyvalues={"event_id": event_id}, - allow_none=True, - desc="get_rejection_reason", - ) diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py deleted file mode 100644 index 7d477f8d01..0000000000 --- a/synapse/storage/data_stores/main/relations.py +++ /dev/null @@ -1,327 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging - -import attr - -from synapse.api.constants import RelationTypes -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.stream import generate_pagination_where_clause -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks - -logger = logging.getLogger(__name__) - - -class RelationsWorkerStore(SQLBaseStore): - @cached(tree=True) - def get_relations_for_event( - self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of relations for an event, ordered by topological ordering. - - Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. - - Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. - """ - - where_clause = ["relates_to_id = ?"] - where_args = [event_id] - - if relation_type is not None: - where_clause.append("relation_type = ?") - where_args.append(relation_type) - - if event_type is not None: - where_clause.append("type = ?") - where_args.append(event_type) - - if aggregation_key: - where_clause.append("aggregation_key = ?") - where_args.append(aggregation_key) - - pagination_clause = generate_pagination_where_clause( - direction=direction, - column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if pagination_clause: - where_clause.append(pagination_clause) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - sql = """ - SELECT event_id, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE %s - ORDER BY topological_ordering %s, stream_ordering %s - LIMIT ? 
- """ % ( - " AND ".join(where_clause), - order, - order, - ) - - def _get_recent_references_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - last_topo_id = None - last_stream_id = None - events = [] - for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] - - next_batch = None - if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.db.runInteraction( - "get_recent_references_for_event", _get_recent_references_for_event_txn - ) - - @cached(tree=True) - def get_aggregation_groups_for_event( - self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of annotations on the event, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - - - Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. - """ - - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args = [event_id, RelationTypes.ANNOTATION] - - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE {where_clause} - GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} - LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) - - def _get_aggregation_groups_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.db.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn - ) - - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): - """Get the most recent edit (if any) that has happened for the given - event. 
- - Correctly handles checking whether edits were allowed to happen. - - Args: - event_id (str): The original event ID - - Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. - """ - - # We only allow edits for `m.room.message` events that have the same sender - # and event type. We can't assert these things during regular event auth so - # we have to do the checks post hoc. - - # Fetches latest edit that has the same type and sender as the - # original, and is an `m.room.message`. - sql = """ - SELECT edit.event_id FROM events AS edit - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS original ON - original.event_id = relates_to_id - AND edit.type = original.type - AND edit.sender = original.sender - WHERE - relates_to_id = ? - AND relation_type = ? - AND edit.type = 'm.room.message' - ORDER by edit.origin_server_ts DESC, edit.event_id DESC - LIMIT 1 - """ - - def _get_applicable_edit_txn(txn): - txn.execute(sql, (event_id, RelationTypes.REPLACE)) - row = txn.fetchone() - if row: - return row[0] - - edit_id = yield self.db.runInteraction( - "get_applicable_edit", _get_applicable_edit_txn - ) - - if not edit_id: - return - - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event - - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): - """Check if a user has already annotated an event with the same key - (e.g. already liked an event). - - Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation - - Returns: - Deferred[bool] - """ - - sql = """ - SELECT 1 FROM event_relations - INNER JOIN events USING (event_id) - WHERE - relates_to_id = ? - AND relation_type = ? - AND type = ? - AND sender = ? - AND aggregation_key = ? - LIMIT 1; - """ - - def _get_if_user_has_annotated_event(txn): - txn.execute( - sql, - ( - parent_id, - RelationTypes.ANNOTATION, - event_type, - sender, - aggregation_key, - ), - ) - - return bool(txn.fetchone()) - - return self.db.runInteraction( - "get_if_user_has_annotated_event", _get_if_user_has_annotated_event - ) - - -class RelationsStore(RelationsWorkerStore): - pass diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py deleted file mode 100644 index ab48052cdc..0000000000 --- a/synapse/storage/data_stores/main/room.py +++ /dev/null @@ -1,1425 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
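
get_relations_for_event above uses the common limit + 1 trick: the query asks for one more row than the caller wants, and the surplus row both signals that another page exists and supplies the next_batch token. A rough sketch of the shape of that logic with generic names (the real code builds a RelationPaginationToken from topological and stream orderings):

    def paginate(fetched_rows, limit):
        # `fetched_rows` is assumed to hold at most limit + 1 rows, as
        # returned by the "LIMIT ?" query above bound with limit + 1.
        page = fetched_rows[:limit]
        next_token = fetched_rows[limit] if len(fetched_rows) > limit else None
        return page, next_token

    page, token = paginate(list(range(7)), 5)   # page == [0, 1, 2, 3, 4], token == 5
    page, token = paginate([1, 2, 3], 5)        # token is None: this was the last page
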
- -import collections -import logging -import re -from abc import abstractmethod -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple - -from canonicaljson import json - -from synapse.api.constants import EventTypes -from synapse.api.errors import StoreError -from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.data_stores.main.search import SearchStore -from synapse.storage.database import Database, LoggingTransaction -from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - - -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) - - -class RoomSortOrder(Enum): - """ - Enum to define the sorting method used when returning rooms with get_rooms_paginate - - NAME = sort rooms alphabetically by name - JOINED_MEMBERS = sort rooms by membership size, highest to lowest - """ - - # ALPHABETICAL and SIZE are deprecated. - # ALPHABETICAL is the same as NAME. - ALPHABETICAL = "alphabetical" - # SIZE is the same as JOINED_MEMBERS. - SIZE = "size" - NAME = "name" - CANONICAL_ALIAS = "canonical_alias" - JOINED_MEMBERS = "joined_members" - JOINED_LOCAL_MEMBERS = "joined_local_members" - VERSION = "version" - CREATOR = "creator" - ENCRYPTION = "encryption" - FEDERATABLE = "federatable" - PUBLIC = "public" - JOIN_RULES = "join_rules" - GUEST_ACCESS = "guest_access" - HISTORY_VISIBILITY = "history_visibility" - STATE_EVENTS = "state_events" - - -class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(RoomWorkerStore, self).__init__(database, db_conn, hs) - - self.config = hs.config - - def get_room(self, room_id): - """Retrieve a room. - - Args: - room_id (str): The ID of the room to retrieve. - Returns: - A dict containing the room information, or None if the room is unknown. - """ - return self.db.simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), - desc="get_room", - allow_none=True, - ) - - def get_room_with_stats(self, room_id: str): - """Retrieve room with statistics. - - Args: - room_id: The ID of the room to retrieve. - Returns: - A dict containing the room information, or None if the room is unknown. - """ - - def get_room_with_stats_txn(txn, room_id): - sql = """ - SELECT room_id, state.name, state.canonical_alias, curr.joined_members, - curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, - rooms.creator, state.encryption, state.is_federatable AS federatable, - rooms.is_public AS public, state.join_rules, state.guest_access, - state.history_visibility, curr.current_state_events AS state_events - FROM rooms - LEFT JOIN room_stats_state state USING (room_id) - LEFT JOIN room_stats_current curr USING (room_id) - WHERE room_id = ? 
-            """
-            txn.execute(sql, [room_id])
-            # Catch error if sql returns empty result to return "None" instead of an error
-            try:
-                res = self.db.cursor_to_dict(txn)[0]
-            except IndexError:
-                return None
-
-            res["federatable"] = bool(res["federatable"])
-            res["public"] = bool(res["public"])
-            return res
-
-        return self.db.runInteraction(
-            "get_room_with_stats", get_room_with_stats_txn, room_id
-        )
-
-    def get_public_room_ids(self):
-        return self.db.simple_select_onecol(
-            table="rooms",
-            keyvalues={"is_public": True},
-            retcol="room_id",
-            desc="get_public_room_ids",
-        )
-
-    def count_public_rooms(self, network_tuple, ignore_non_federatable):
-        """Counts the number of public rooms as tracked in the room_stats_current
-        and room_stats_state tables.
-
-        Args:
-            network_tuple (ThirdPartyInstanceID|None)
-            ignore_non_federatable (bool): If true, filters out non-federatable rooms
-        """
-
-        def _count_public_rooms_txn(txn):
-            query_args = []
-
-            if network_tuple:
-                if network_tuple.appservice_id:
-                    published_sql = """
-                        SELECT room_id from appservice_room_list
-                        WHERE appservice_id = ? AND network_id = ?
-                    """
-                    query_args.append(network_tuple.appservice_id)
-                    query_args.append(network_tuple.network_id)
-                else:
-                    published_sql = """
-                        SELECT room_id FROM rooms WHERE is_public
-                    """
-            else:
-                published_sql = """
-                    SELECT room_id FROM rooms WHERE is_public
-                    UNION SELECT room_id from appservice_room_list
-                """
-
-            sql = """
-                SELECT
-                    COALESCE(COUNT(*), 0)
-                FROM (
-                    %(published_sql)s
-                ) published
-                INNER JOIN room_stats_state USING (room_id)
-                INNER JOIN room_stats_current USING (room_id)
-                WHERE
-                    (
-                        join_rules = 'public' OR history_visibility = 'world_readable'
-                    )
-                    AND joined_members > 0
-            """ % {
-                "published_sql": published_sql
-            }
-
-            txn.execute(sql, query_args)
-            return txn.fetchone()[0]
-
-        return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
-
-    async def get_largest_public_rooms(
-        self,
-        network_tuple: Optional[ThirdPartyInstanceID],
-        search_filter: Optional[dict],
-        limit: Optional[int],
-        bounds: Optional[Tuple[int, str]],
-        forwards: bool,
-        ignore_non_federatable: bool = False,
-    ):
-        """Gets the largest public rooms (where largest is in terms of joined
-        members, as tracked in the statistics table).
-
-        Args:
-            network_tuple
-            search_filter
-            limit: Maximum number of rows to return, unlimited otherwise.
-            bounds: An upper or lower bound to apply to the result set if given,
-                consists of a joined member count and room_id (these are
-                excluded from the result set).
-            forwards: true iff going forwards, going backwards otherwise
-            ignore_non_federatable: If true, filters out non-federatable rooms.
-
-        Returns:
-            Rooms in order: biggest number of joined users first.
-            We then arbitrarily use the room_id as a tie breaker.
-
-        """
-
-        where_clauses = []
-        query_args = []
-
-        if network_tuple:
-            if network_tuple.appservice_id:
-                published_sql = """
-                    SELECT room_id from appservice_room_list
-                    WHERE appservice_id = ? AND network_id = ?
-                """
-                query_args.append(network_tuple.appservice_id)
-                query_args.append(network_tuple.network_id)
-            else:
-                published_sql = """
-                    SELECT room_id FROM rooms WHERE is_public
-                """
-        else:
-            published_sql = """
-                SELECT room_id FROM rooms WHERE is_public
-                UNION SELECT room_id from appservice_room_list
-            """
-
-        # Work out the bounds if we're given them, these bounds look slightly
-        # odd, but are designed to help the query planner use indices by pulling
-        # out a common bound.
- if bounds: - last_joined_members, last_room_id = bounds - if forwards: - where_clauses.append( - """ - joined_members <= ? AND ( - joined_members < ? OR room_id < ? - ) - """ - ) - else: - where_clauses.append( - """ - joined_members >= ? AND ( - joined_members > ? OR room_id > ? - ) - """ - ) - - query_args += [last_joined_members, last_joined_members, last_room_id] - - if ignore_non_federatable: - where_clauses.append("is_federatable") - - if search_filter and search_filter.get("generic_search_term", None): - search_term = "%" + search_filter["generic_search_term"] + "%" - - where_clauses.append( - """ - ( - LOWER(name) LIKE ? - OR LOWER(topic) LIKE ? - OR LOWER(canonical_alias) LIKE ? - ) - """ - ) - query_args += [ - search_term.lower(), - search_term.lower(), - search_term.lower(), - ] - - where_clause = "" - if where_clauses: - where_clause = " AND " + " AND ".join(where_clauses) - - sql = """ - SELECT - room_id, name, topic, canonical_alias, joined_members, - avatar, history_visibility, joined_members, guest_access - FROM ( - %(published_sql)s - ) published - INNER JOIN room_stats_state USING (room_id) - INNER JOIN room_stats_current USING (room_id) - WHERE - ( - join_rules = 'public' OR history_visibility = 'world_readable' - ) - AND joined_members > 0 - %(where_clause)s - ORDER BY joined_members %(dir)s, room_id %(dir)s - """ % { - "published_sql": published_sql, - "where_clause": where_clause, - "dir": "DESC" if forwards else "ASC", - } - - if limit is not None: - query_args.append(limit) - - sql += """ - LIMIT ? - """ - - def _get_largest_public_rooms_txn(txn): - txn.execute(sql, query_args) - - results = self.db.cursor_to_dict(txn) - - if not forwards: - results.reverse() - - return results - - ret_val = await self.db.runInteraction( - "get_largest_public_rooms", _get_largest_public_rooms_txn - ) - return ret_val - - @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db.simple_select_one_onecol( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - retcol="1", - allow_none=True, - desc="is_room_blocked", - ) - - async def get_rooms_paginate( - self, - start: int, - limit: int, - order_by: RoomSortOrder, - reverse_order: bool, - search_term: Optional[str], - ) -> Tuple[List[Dict[str, Any]], int]: - """Function to retrieve a paginated list of rooms as json. - - Args: - start: offset in the list - limit: maximum amount of rooms to retrieve - order_by: the sort order of the returned list - reverse_order: whether to reverse the room list - search_term: a string to filter room names by - Returns: - A list of room dicts and an integer representing the total number of - rooms that exist given this query - """ - # Filter room names by a string - where_statement = "" - if search_term: - where_statement = "WHERE state.name LIKE ?" - - # Our postgres db driver converts ? -> %s in SQL strings as that's the - # placeholder for postgres. - # HOWEVER, if you put a % into your SQL then everything goes wibbly. 
- # To get around this, we're going to surround search_term with %'s - # before giving it to the database in python instead - search_term = "%" + search_term + "%" - - # Set ordering - if RoomSortOrder(order_by) == RoomSortOrder.SIZE: - # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS - order_by_column = "curr.joined_members" - order_by_asc = False - elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL: - # Deprecated in favour of RoomSortOrder.NAME - order_by_column = "state.name" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.NAME: - order_by_column = "state.name" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS: - order_by_column = "state.canonical_alias" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS: - order_by_column = "curr.joined_members" - order_by_asc = False - elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS: - order_by_column = "curr.local_users_in_room" - order_by_asc = False - elif RoomSortOrder(order_by) == RoomSortOrder.VERSION: - order_by_column = "rooms.room_version" - order_by_asc = False - elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR: - order_by_column = "rooms.creator" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION: - order_by_column = "state.encryption" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE: - order_by_column = "state.is_federatable" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC: - order_by_column = "rooms.is_public" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES: - order_by_column = "state.join_rules" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS: - order_by_column = "state.guest_access" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY: - order_by_column = "state.history_visibility" - order_by_asc = True - elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS: - order_by_column = "curr.current_state_events" - order_by_asc = False - else: - raise StoreError( - 500, "Incorrect value for order_by provided: %s" % order_by - ) - - # Whether to return the list in reverse order - if reverse_order: - # Flip the boolean - order_by_asc = not order_by_asc - - # Create one query for getting the limited number of events that the user asked - # for, and another query for getting the total number of events that could be - # returned. Thus allowing us to see if there are more events to paginate through - info_sql = """ - SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, - curr.local_users_in_room, rooms.room_version, rooms.creator, - state.encryption, state.is_federatable, rooms.is_public, state.join_rules, - state.guest_access, state.history_visibility, curr.current_state_events - FROM room_stats_state state - INNER JOIN room_stats_current curr USING (room_id) - INNER JOIN rooms USING (room_id) - %s - ORDER BY %s %s - LIMIT ? - OFFSET ? 
-        """ % (
-            where_statement,
-            order_by_column,
-            "ASC" if order_by_asc else "DESC",
-        )
-
-        # Use a nested SELECT statement as SQL can't count(*) with an OFFSET
-        count_sql = """
-            SELECT count(*) FROM (
-                SELECT room_id FROM room_stats_state state
-                %s
-            ) AS get_room_ids
-        """ % (
-            where_statement,
-        )
-
-        def _get_rooms_paginate_txn(txn):
-            # Execute the data query
-            sql_values = (limit, start)
-            if search_term:
-                # Add the search term into the WHERE clause
-                sql_values = (search_term,) + sql_values
-            txn.execute(info_sql, sql_values)
-
-            # Refactor room query data into a structured dictionary
-            rooms = []
-            for room in txn:
-                rooms.append(
-                    {
-                        "room_id": room[0],
-                        "name": room[1],
-                        "canonical_alias": room[2],
-                        "joined_members": room[3],
-                        "joined_local_members": room[4],
-                        "version": room[5],
-                        "creator": room[6],
-                        "encryption": room[7],
-                        "federatable": room[8],
-                        "public": room[9],
-                        "join_rules": room[10],
-                        "guest_access": room[11],
-                        "history_visibility": room[12],
-                        "state_events": room[13],
-                    }
-                )
-
-            # Execute the count query
-
-            # Add the search term into the WHERE clause if present
-            sql_values = (search_term,) if search_term else ()
-            txn.execute(count_sql, sql_values)
-
-            room_count = txn.fetchone()
-            return rooms, room_count[0]
-
-        return await self.db.runInteraction(
-            "get_rooms_paginate", _get_rooms_paginate_txn,
-        )
-
-    @cached(max_entries=10000)
-    async def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
-
-        Args:
-            user_id (str)
-
-        Returns:
-            RatelimitOverride if there is an override, else None. If the contents
-            of RatelimitOverride are None or 0 then ratelimiting has been
-            disabled for that user entirely.
-        """
-        row = await self.db.simple_select_one(
-            table="ratelimit_override",
-            keyvalues={"user_id": user_id},
-            retcols=("messages_per_second", "burst_count"),
-            allow_none=True,
-            desc="get_ratelimit_for_user",
-        )
-
-        if row:
-            return RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            )
-        else:
-            return None
-
-    @cached()
-    async def get_retention_policy_for_room(self, room_id):
-        """Get the retention policy for a given room.
-
-        If no retention policy has been found for this room, returns a policy defined
-        by the configured default policy (which has None as both the 'min_lifetime' and
-        the 'max_lifetime' if no default policy has been defined in the server's
-        configuration).
-
-        Args:
-            room_id (str): The ID of the room to get the retention policy of.
-
-        Returns:
-            dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
-        """
-
-        def get_retention_policy_for_room_txn(txn):
-            txn.execute(
-                """
-                SELECT min_lifetime, max_lifetime FROM room_retention
-                INNER JOIN current_state_events USING (event_id, room_id)
-                WHERE room_id = ?;
-                """,
-                (room_id,),
-            )
-
-            return self.db.cursor_to_dict(txn)
-
-        ret = await self.db.runInteraction(
-            "get_retention_policy_for_room", get_retention_policy_for_room_txn,
-        )
-
-        # If we don't know this room ID, ret will be None; in this case return the default
-        # policy.
-        if not ret:
-            return {
-                "min_lifetime": self.config.retention_default_min_lifetime,
-                "max_lifetime": self.config.retention_default_max_lifetime,
-            }
-
-        row = ret[0]
-
-        # If one of the room's policy's attributes isn't defined, use the matching
-        # attribute from the default policy.
- # The default values will be None if no default policy has been defined, or if one - # of the attributes is missing from the default policy. - if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention_default_min_lifetime - - if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention_default_max_lifetime - - return row - - def get_media_mxcs_in_room(self, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - - def _get_media_mxcs_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - local_media_mxcs = [] - remote_media_mxcs = [] - - # Convert the IDs to MXC URIs - for media_id in local_mxcs: - local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) - for hostname, media_id in remote_mxcs: - remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs - - return self.db.runInteraction( - "get_media_ids_in_room", _get_media_mxcs_in_room_txn - ) - - def quarantine_media_ids_in_room(self, room_id, quarantined_by): - """For a room loops through all events with media and quarantines - the associated media - """ - - logger.info("Quarantining media in room: %s", room_id) - - def _quarantine_media_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - return self._quarantine_media_txn( - txn, local_mxcs, remote_mxcs, quarantined_by - ) - - return self.db.runInteraction( - "quarantine_media_in_room", _quarantine_media_in_room_txn - ) - - def _get_media_mxcs_in_room_txn(self, txn, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - txn (cursor) - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - - sql = """ - SELECT stream_ordering, json FROM events - JOIN event_json USING (room_id, event_id) - WHERE room_id = ? - %(where_clause)s - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? - """ - txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) - - local_media_mxcs = [] - remote_media_mxcs = [] - - while True: - next_token = None - for stream_ordering, content_json in txn: - next_token = stream_ordering - event_json = db_to_json(content_json) - content = event_json["content"] - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hs.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - if next_token is None: - # We've gone through the whole room, so we're finished. 
- break - - txn.execute( - sql % {"where_clause": "AND stream_ordering < ?"}, - (room_id, next_token, True, False, 100), - ) - - return local_media_mxcs, remote_media_mxcs - - def quarantine_media_by_id( - self, server_name: str, media_id: str, quarantined_by: str, - ): - """quarantines a single local or remote media id - - Args: - server_name: The name of the server that holds this media - media_id: The ID of the media to be quarantined - quarantined_by: The user ID that initiated the quarantine request - """ - logger.info("Quarantining media: %s/%s", server_name, media_id) - is_local = server_name == self.config.server_name - - def _quarantine_media_by_id_txn(txn): - local_mxcs = [media_id] if is_local else [] - remote_mxcs = [(server_name, media_id)] if not is_local else [] - - return self._quarantine_media_txn( - txn, local_mxcs, remote_mxcs, quarantined_by - ) - - return self.db.runInteraction( - "quarantine_media_by_user", _quarantine_media_by_id_txn - ) - - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): - """quarantines all local media associated with a single user - - Args: - user_id: The ID of the user to quarantine media of - quarantined_by: The ID of the user who made the quarantine request - """ - - def _quarantine_media_by_user_txn(txn): - local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) - return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - - return self.db.runInteraction( - "quarantine_media_by_user", _quarantine_media_by_user_txn - ) - - def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): - """Retrieves local media IDs by a given user - - Args: - txn (cursor) - user_id: The ID of the user to retrieve media IDs of - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - # Local media - sql = """ - SELECT media_id - FROM local_media_repository - WHERE user_id = ? - """ - if filter_quarantined: - sql += "AND quarantined_by IS NULL" - txn.execute(sql, (user_id,)) - - local_media_ids = [row[0] for row in txn] - - # TODO: Figure out all remote media a user has referenced in a message - - return local_media_ids - - def _quarantine_media_txn( - self, - txn, - local_mxcs: List[str], - remote_mxcs: List[Tuple[str, str]], - quarantined_by: str, - ) -> int: - """Quarantine local and remote media items - - Args: - txn (cursor) - local_mxcs: A list of local mxc URLs - remote_mxcs: A list of (remote server, media id) tuples representing - remote mxc URLs - quarantined_by: The ID of the user who initiated the quarantine request - Returns: - The total number of media items quarantined - """ - # Update all the tables to set the quarantined_by flag - txn.executemany( - """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? AND safe_from_quarantine = ? - """, - ((quarantined_by, media_id, False) for media_id in local_mxcs), - ) - # Note that a rowcount of -1 can be used to indicate no rows were affected. - total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? 
- """, - ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), - ) - total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 - - return total_media_quarantined - - async def get_all_new_public_rooms( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for public rooms replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - if last_id == current_id: - return [], current_id, False - - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - - -class RoomBackgroundUpdateStore(SQLBaseStore): - REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" - ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" - - def __init__(self, database: Database, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.config = hs.config - - self.db.updates.register_background_update_handler( - "insert_room_retention", self._background_insert_retention, - ) - - self.db.updates.register_background_update_handler( - self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, - self._remove_tombstoned_rooms_from_directory, - ) - - self.db.updates.register_background_update_handler( - self.ADD_ROOMS_ROOM_VERSION_COLUMN, - self._background_add_rooms_room_version_column, - ) - - async def _background_insert_retention(self, progress, batch_size): - """Retrieves a list of all rooms within a range and inserts an entry for each of - them into the room_retention table. - NULLs the property's columns if missing from the retention event in the room's - state (or NULLs all of them if there's no retention event in the room's state), - so that we fall back to the server's retention policy. - """ - - last_room = progress.get("room_id", "") - - def _background_insert_retention_txn(txn): - txn.execute( - """ - SELECT state.room_id, state.event_id, events.json - FROM current_state_events as state - LEFT JOIN event_json AS events ON (state.event_id = events.event_id) - WHERE state.room_id > ? 
AND state.type = '%s' - ORDER BY state.room_id ASC - LIMIT ?; - """ - % EventTypes.Retention, - (last_room, batch_size), - ) - - rows = self.db.cursor_to_dict(txn) - - if not rows: - return True - - for row in rows: - if not row["json"]: - retention_policy = {} - else: - ev = db_to_json(row["json"]) - retention_policy = ev["content"] - - self.db.simple_insert_txn( - txn=txn, - table="room_retention", - values={ - "room_id": row["room_id"], - "event_id": row["event_id"], - "min_lifetime": retention_policy.get("min_lifetime"), - "max_lifetime": retention_policy.get("max_lifetime"), - }, - ) - - logger.info("Inserted %d rows into room_retention", len(rows)) - - self.db.updates._background_update_progress_txn( - txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} - ) - - if batch_size > len(rows): - return True - else: - return False - - end = await self.db.runInteraction( - "insert_room_retention", _background_insert_retention_txn, - ) - - if end: - await self.db.updates._end_background_update("insert_room_retention") - - return batch_size - - async def _background_add_rooms_room_version_column( - self, progress: dict, batch_size: int - ): - """Background update to go and add room version inforamtion to `rooms` - table from `current_state_events` table. - """ - - last_room_id = progress.get("room_id", "") - - def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): - sql = """ - SELECT room_id, json FROM current_state_events - INNER JOIN event_json USING (room_id, event_id) - WHERE room_id > ? AND type = 'm.room.create' AND state_key = '' - ORDER BY room_id - LIMIT ? - """ - - txn.execute(sql, (last_room_id, batch_size)) - - updates = [] - for room_id, event_json in txn: - event_dict = db_to_json(event_json) - room_version_id = event_dict.get("content", {}).get( - "room_version", RoomVersions.V1.identifier - ) - - creator = event_dict.get("content").get("creator") - - updates.append((room_id, creator, room_version_id)) - - if not updates: - return True - - new_last_room_id = "" - for room_id, creator, room_version_id in updates: - # We upsert here just in case we don't already have a row, - # mainly for paranoia as much badness would happen if we don't - # insert the row and then try and get the room version for the - # room. - self.db.simple_upsert_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - values={"room_version": room_version_id}, - insertion_values={"is_public": False, "creator": creator}, - ) - new_last_room_id = room_id - - self.db.updates._background_update_progress_txn( - txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} - ) - - return False - - end = await self.db.runInteraction( - "_background_add_rooms_room_version_column", - _background_add_rooms_room_version_column_txn, - ) - - if end: - await self.db.updates._end_background_update( - self.ADD_ROOMS_ROOM_VERSION_COLUMN - ) - - return batch_size - - async def _remove_tombstoned_rooms_from_directory( - self, progress, batch_size - ) -> int: - """Removes any rooms with tombstone events from the room directory - - Nowadays this is handled by the room upgrade handler, but we may have some - that got left behind - """ - - last_room = progress.get("room_id", "") - - def _get_rooms(txn): - txn.execute( - """ - SELECT room_id - FROM rooms r - INNER JOIN current_state_events cse USING (room_id) - WHERE room_id > ? 
AND r.is_public - AND cse.type = '%s' AND cse.state_key = '' - ORDER BY room_id ASC - LIMIT ?; - """ - % EventTypes.Tombstone, - (last_room, batch_size), - ) - - return [row[0] for row in txn] - - rooms = await self.db.runInteraction( - "get_tombstoned_directory_rooms", _get_rooms - ) - - if not rooms: - await self.db.updates._end_background_update( - self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE - ) - return 0 - - for room_id in rooms: - logger.info("Removing tombstoned room %s from the directory", room_id) - await self.set_room_is_public(room_id, False) - - await self.db.updates._background_update_progress( - self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} - ) - - return len(rooms) - - @abstractmethod - def set_room_is_public(self, room_id, is_public): - # this will need to be implemented if a background update is performed with - # existing (tombstoned, public) rooms in the database. - # - # It's overridden by RoomStore for the synapse master. - raise NotImplementedError() - - -class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: Database, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) - - self.config = hs.config - - async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): - """Ensure that the room is stored in the table - - Called when we join a room over federation, and overwrites any room version - currently in the table. - """ - await self.db.simple_upsert( - desc="upsert_room_on_join", - table="rooms", - keyvalues={"room_id": room_id}, - values={"room_version": room_version.identifier}, - insertion_values={"is_public": False, "creator": ""}, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, - ) - - async def store_room( - self, - room_id: str, - room_creator_user_id: str, - is_public: bool, - room_version: RoomVersion, - ): - """Stores a room. - - Args: - room_id: The desired room ID, can be None. - room_creator_user_id: The user ID of the room creator. - is_public: True to indicate that this room should appear in - public room lists. - room_version: The version of the room - Raises: - StoreError if the room could not be stored. - """ - try: - - def store_room_txn(txn, next_id): - self.db.simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - "room_version": room_version.identifier, - }, - ) - if is_public: - self.db.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - await self.db.runInteraction("store_room_txn", store_room_txn, next_id) - except Exception as e: - logger.error("store_room with room_id=%s failed: %s", room_id, e) - raise StoreError(500, "Problem creating room.") - - async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): - """ - When we receive an invite over federation, store the version of the room if we - don't already know the room version. - """ - await self.db.simple_upsert( - desc="maybe_store_room_on_invite", - table="rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={ - "room_version": room_version.identifier, - "is_public": False, - "creator": "", - }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. 
- lock=False, - ) - - async def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self.db.simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self.db.simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - await self.db.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) - self.hs.get_notifier().on_new_replication_data() - - async def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): - """Edit the appservice/network specific public room list. - - Each appservice can have a number of published room lists associated - with them, keyed off of an appservice defined `network_id`, which - basically represents a single instance of a bridge to a third party - network. - - Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. - """ - - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self.db.simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. 
-                    return
-            else:
-                self.db.simple_delete_txn(
-                    txn,
-                    table="appservice_room_list",
-                    keyvalues={
-                        "appservice_id": appservice_id,
-                        "network_id": network_id,
-                        "room_id": room_id,
-                    },
-                )
-
-            entries = self.db.simple_select_list_txn(
-                txn,
-                table="public_room_list_stream",
-                keyvalues={
-                    "room_id": room_id,
-                    "appservice_id": appservice_id,
-                    "network_id": network_id,
-                },
-                retcols=("stream_id", "visibility"),
-            )
-
-            entries.sort(key=lambda r: r["stream_id"])
-
-            add_to_stream = True
-            if entries:
-                add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
-            if add_to_stream:
-                self.db.simple_insert_txn(
-                    txn,
-                    table="public_room_list_stream",
-                    values={
-                        "stream_id": next_id,
-                        "room_id": room_id,
-                        "visibility": is_public,
-                        "appservice_id": appservice_id,
-                        "network_id": network_id,
-                    },
-                )
-
-        with self._public_room_id_gen.get_next() as next_id:
-            await self.db.runInteraction(
-                "set_room_is_public_appservice",
-                set_room_is_public_appservice_txn,
-                next_id,
-            )
-        self.hs.get_notifier().on_new_replication_data()
-
-    def get_room_count(self):
-        """Retrieve the total number of rooms.
-        """
-
-        def f(txn):
-            sql = "SELECT count(*) FROM rooms"
-            txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
-
-        return self.db.runInteraction("get_rooms", f)
-
-    def add_event_report(
-        self, room_id, event_id, user_id, reason, content, received_ts
-    ):
-        next_id = self._event_reports_id_gen.get_next()
-        return self.db.simple_insert(
-            table="event_reports",
-            values={
-                "id": next_id,
-                "received_ts": received_ts,
-                "room_id": room_id,
-                "event_id": event_id,
-                "user_id": user_id,
-                "reason": reason,
-                "content": json.dumps(content),
-            },
-            desc="add_event_report",
-        )
-
-    def get_current_public_room_stream_id(self):
-        return self._public_room_id_gen.get_current_token()
-
-    async def block_room(self, room_id: str, user_id: str) -> None:
-        """Marks the room as blocked. Can be called multiple times.
-
-        Args:
-            room_id: Room to block
-            user_id: Who blocked it
-        """
-        await self.db.simple_upsert(
-            table="blocked_rooms",
-            keyvalues={"room_id": room_id},
-            values={},
-            insertion_values={"user_id": user_id},
-            desc="block_room",
-        )
-        await self.db.runInteraction(
-            "block_room_invalidation",
-            self._invalidate_cache_and_stream,
-            self.is_room_blocked,
-            (room_id,),
-        )
-
-    async def get_rooms_for_retention_period_in_range(
-        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
-        """Retrieves all of the rooms within the given retention range.
-
-        Optionally includes the rooms which don't have a retention policy.
-
-        Args:
-            min_ms: Duration in milliseconds that defines the lower limit of
-                the range to handle (exclusive). If None, doesn't set a lower limit.
-            max_ms: Duration in milliseconds that defines the upper limit of
-                the range to handle (inclusive). If None, doesn't set an upper limit.
-            include_null: Whether to include rooms whose retention policy is NULL
-                in the returned set.
-
-        Returns:
-            The rooms within this range, along with their retention
-            policy. The key is "room_id", and maps to a dict describing the retention
-            policy associated with this room ID. The keys for this nested dict are
-            "min_lifetime" (int|None), and "max_lifetime" (int|None).
- """ - - def get_rooms_for_retention_period_in_range_txn(txn): - range_conditions = [] - args = [] - - if min_ms is not None: - range_conditions.append("max_lifetime > ?") - args.append(min_ms) - - if max_ms is not None: - range_conditions.append("max_lifetime <= ?") - args.append(max_ms) - - # Do a first query which will retrieve the rooms that have a retention policy - # in their current state. - sql = """ - SELECT room_id, min_lifetime, max_lifetime FROM room_retention - INNER JOIN current_state_events USING (event_id, room_id) - """ - - if len(range_conditions): - sql += " WHERE (" + " AND ".join(range_conditions) + ")" - - if include_null: - sql += " OR max_lifetime IS NULL" - - txn.execute(sql, args) - - rows = self.db.cursor_to_dict(txn) - rooms_dict = {} - - for row in rows: - rooms_dict[row["room_id"]] = { - "min_lifetime": row["min_lifetime"], - "max_lifetime": row["max_lifetime"], - } - - if include_null: - # If required, do a second query that retrieves all of the rooms we know - # of so we can handle rooms with no retention policy. - sql = "SELECT DISTINCT room_id FROM current_state_events" - - txn.execute(sql) - - rows = self.db.cursor_to_dict(txn) - - # If a room isn't already in the dict (i.e. it doesn't have a retention - # policy in its state), add it with a null policy. - for row in rows: - if row["room_id"] not in rooms_dict: - rooms_dict[row["room_id"]] = { - "min_lifetime": None, - "max_lifetime": None, - } - - return rooms_dict - - rooms = await self.db.runInteraction( - "get_rooms_for_retention_period_in_range", - get_rooms_for_retention_period_in_range_txn, - ) - - return rooms diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py deleted file mode 100644 index a92e401e88..0000000000 --- a/synapse/storage/data_stores/main/roommember.py +++ /dev/null @@ -1,1135 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import Iterable, List, Set - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database -from synapse.storage.engines import Sqlite3Engine -from synapse.storage.roommember import ( - GetRoomsForUserWithStreamOrdering, - MemberSummary, - ProfileInfo, - RoomsForUser, -) -from synapse.types import Collection, get_domain_from_id -from synapse.util.async_helpers import Linearizer -from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from synapse.util.metrics import Measure - -logger = logging.getLogger(__name__) - - -_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" -_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" - - -class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: Database, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) - - # Is the current_state_events.membership up to date? Or is the - # background update still running? - self._current_state_events_membership_up_to_date = False - - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, - ) - self._check_safe_current_state_events_membership_updated_txn(txn) - txn.close() - - if self.hs.config.metrics_flags.known_servers: - self._known_servers_count = 1 - self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, - ) - self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, - ) - LaterGauge( - "synapse_federation_known_servers", - "", - [], - lambda: self._known_servers_count, - ) - - @defer.inlineCallbacks - def _count_known_servers(self): - """ - Count the servers that this server knows about. - - The statistic is stored on the class for the - `synapse_federation_known_servers` LaterGauge to collect. - """ - - def _transact(txn): - if isinstance(self.database_engine, Sqlite3Engine): - query = """ - SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) - FROM ( - SELECT rm.user_id as user_id, instr(rm.user_id, ':') - AS pos FROM room_memberships as rm - INNER JOIN current_state_events as c ON rm.event_id = c.event_id - WHERE c.type = 'm.room.member' - ) as out - """ - else: - query = """ - SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) - FROM current_state_events - WHERE type = 'm.room.member' AND membership = 'join'; - """ - txn.execute(query) - return list(txn)[0][0] - - count = yield self.db.runInteraction("get_known_servers", _transact) - - # We always know about ourselves, even if we have nothing in - # room_memberships (for example, the server is new). 
- self._known_servers_count = max([count, 1]) - return self._known_servers_count - - def _check_safe_current_state_events_membership_updated_txn(self, txn): - """Checks if it is safe to assume the new current_state_events - membership column is up to date - """ - - pending_update = self.db.simple_select_one_txn( - txn, - table="background_updates", - keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, - retcols=["update_name"], - allow_none=True, - ) - - self._current_state_events_membership_up_to_date = not pending_update - - # If the update is still running, reschedule to run. - if pending_update: - self._clock.call_later( - 15.0, - run_as_background_process, - "_check_safe_current_state_events_membership_updated", - self.db.runInteraction, - "_check_safe_current_state_events_membership_updated", - self._check_safe_current_state_events_membership_updated_txn, - ) - - @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): - return self.db.runInteraction( - "get_users_in_room", self.get_users_in_room_txn, room_id - ) - - def get_users_in_room_txn(self, txn, room_id): - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? - """ - else: - sql = """ - SELECT state_key FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - """ - - txn.execute(sql, (room_id, Membership.JOIN)) - return [r[0] for r in txn] - - @cached(max_entries=100000) - def get_room_summary(self, room_id): - """ Get the details of a room roughly suitable for use by the room - summary extension to /sync. Useful when lazy loading room members. - Args: - room_id (str): The room ID to query - Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. - """ - - def _get_room_summary_txn(txn): - # first get counts. - # We do this all in one transaction to keep the cache small. - # FIXME: get rid of this when we have room_stats - - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ - else: - sql = """ - SELECT count(*), m.membership FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? 
- GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - res = {} - for count, membership in txn: - summary = res.setdefault(membership, MemberSummary([], count)) - - # we order by membership and then fairly arbitrarily by event_id so - # heroes are consistent - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT state_key, membership, event_id - FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - ORDER BY - CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - event_id ASC - LIMIT ? - """ - else: - sql = """ - SELECT c.state_key, m.membership, c.event_id - FROM room_memberships as m - INNER JOIN current_state_events as c USING (room_id, event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - ORDER BY - CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - c.event_id ASC - LIMIT ? - """ - - # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. - txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) - for user_id, membership, event_id in txn: - summary = res[membership] - # we will always have a summary for this membership type at this - # point given the summary currently contains the counts. - members = summary.members - members.append((user_id, event_id)) - - return res - - return self.db.runInteraction("get_room_summary", _get_room_summary_txn) - - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} - - @cached() - def get_invited_rooms_for_local_user(self, user_id): - """ Get all the rooms the *local* user is invited to - - Args: - user_id (str): The user ID. - Returns: - A deferred list of RoomsForUser. - """ - - return self.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE] - ) - - @defer.inlineCallbacks - def get_invite_for_local_user_in_room(self, user_id, room_id): - """Gets the invite for the given *local* user and room - - Args: - user_id (str) - room_id (str) - - Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. - """ - invites = yield self.get_invited_rooms_for_local_user(user_id) - for invite in invites: - if invite.room_id == room_id: - return invite - return None - - @defer.inlineCallbacks - def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this *local* user where the membership for this user - matches one in the membership list. - - Filters out forgotten rooms. - - Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. 
- - Returns: - Deferred[list[RoomsForUser]] - """ - if not membership_list: - return defer.succeed(None) - - rooms = yield self.db.runInteraction( - "get_rooms_for_local_user_where_membership_is", - self._get_rooms_for_local_user_where_membership_is_txn, - user_id, - membership_list, - ) - - # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] - - def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): - # Paranoia check. - if not self.hs.is_mine_id(user_id): - raise Exception( - "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r" - % (user_id,), - ) - - clause, args = make_in_list_sql_clause( - self.database_engine, "c.membership", membership_list - ) - - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM local_current_membership AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - user_id = ? - AND %s - """ % ( - clause, - ) - - txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] - - return results - - @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to. - - If a remote user only returns rooms this server is currently - participating in. - - Args: - user_id (str) - - Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. - """ - return self.db.runInteraction( - "get_rooms_for_user_with_stream_ordering", - self._get_rooms_for_user_with_stream_ordering_txn, - user_id, - ) - - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): - # We use `current_state_events` here and not `local_current_membership` - # as a) this gets called with remote users and b) this only gets called - # for rooms the server is participating in. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT room_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND c.membership = ? - """ - else: - sql = """ - SELECT room_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND m.membership = ? - """ - - txn.execute(sql, (user_id, Membership.JOIN)) - results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) - - return results - - async def get_users_server_still_shares_room_with( - self, user_ids: Collection[str] - ) -> Set[str]: - """Given a list of users return the set that the server still share a - room with. 
- """ - - if not user_ids: - return set() - - def _get_users_server_still_shares_room_with_txn(txn): - sql = """ - SELECT state_key FROM current_state_events - WHERE - type = 'm.room.member' - AND membership = 'join' - AND %s - GROUP BY state_key - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "state_key", user_ids - ) - - txn.execute(sql % (clause,), args) - - return {row[0] for row in txn} - - return await self.db.runInteraction( - "get_users_server_still_shares_room_with", - _get_users_server_still_shares_room_with_txn, - ) - - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to. - - If a remote user only returns rooms this server is currently - participating in. - """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate - ) - return frozenset(r.room_id for r in rooms) - - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): - """Returns the set of users who share a room with `user_id` - """ - room_ids = yield self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate - ) - - user_who_share_room = set() - for room_id in room_ids: - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - user_who_share_room.update(user_ids) - - return user_who_share_room - - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, event=event, context=context - ) - return result - - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) - ) - - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( - self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, - ): - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. 
- assert state_group is not None - - users_in_room = {} - member_event_ids = [ - e_id - for key, e_id in current_state_ids.items() - if key[0] == EventTypes.Member - ] - - if context is not None: - # If we have a context with a delta from a previous state group, - # check if we also have the result from the previous group in cache. - # If we do then we can reuse that result and simply update it with - # any membership changes in `delta_ids` - if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( - (room_id, context.prev_group), None - ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) - member_event_ids = [ - e_id - for key, e_id in context.delta_ids.items() - if key[0] == EventTypes.Member - ] - for etype, state_key in context.delta_ids: - if etype == EventTypes.Member: - users_in_room.pop(state_key, None) - - # We check if we have any of the member event ids in the event cache - # before we ask the DB - - # We don't update the event cache hit ratio as it completely throws off - # the hit ratio counts. After all, we don't populate the cache if we - # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) - - missing_member_event_ids = [] - for event_id in member_event_ids: - ev_entry = event_map.get(event_id) - if ev_entry: - if ev_entry.event.membership == Membership.JOIN: - users_in_room[ev_entry.event.state_key] = ProfileInfo( - display_name=ev_entry.event.content.get("displayname", None), - avatar_url=ev_entry.event.content.get("avatar_url", None), - ) - else: - missing_member_event_ids.append(event_id) - - if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( - missing_member_event_ids - ) - users_in_room.update((row for row in event_to_memberships.values() if row)) - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[event.state_key] = ProfileInfo( - display_name=event.content.get("displayname", None), - avatar_url=event.content.get("avatar_url", None), - ) - - return users_in_room - - @cached(max_entries=10000) - def _get_joined_profile_from_event_id(self, event_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, - ) - def _get_joined_profiles_from_event_ids(self, event_ids): - """For given set of member event_ids check if they point to a join - event and if so return the associated user and profile info. - - Args: - event_ids (Iterable[str]): The member event IDs to lookup - - Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID - to `user_id` and ProfileInfo (or None if not join event). 
- """ - - rows = yield self.db.simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), - keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_membership_from_event_ids", - ) - - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } - - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT state_key FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' - AND type = 'm.room.member' - AND c.room_id = ? - AND state_key LIKE ? - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) - ) - - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. 
- assert state_group is not None - - cache = yield self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts - - @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): - return _JoinedHostsCache(self, room_id) - - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): - """Returns whether user_id has elected to discard history for room_id. - - Returns False if they have since re-joined.""" - - def f(txn): - sql = ( - "SELECT" - " COUNT(*)" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " forgotten = 0" - ) - txn.execute(sql, (user_id, room_id)) - rows = txn.fetchall() - return rows[0][0] - - count = yield self.db.runInteraction("did_forget_membership", f) - return count == 0 - - @cached() - def get_forgotten_rooms_for_user(self, user_id): - """Gets all rooms the user has forgotten. - - Args: - user_id (str) - - Returns: - Deferred[set[str]] - """ - - def _get_forgotten_rooms_for_user_txn(txn): - # This is a slightly convoluted query that first looks up all rooms - # that the user has forgotten in the past, then rechecks that list - # to see if any have subsequently been updated. This is done so that - # we can use a partial index on `forgotten = 1` on the assumption - # that few users will actually forget many rooms. - # - # Note that a room is considered "forgotten" if *all* membership - # events for that user and room have the forgotten field set (as - # when a user forgets a room we update all rows for that user and - # room, not just the current one). - sql = """ - SELECT room_id, ( - SELECT count(*) FROM room_memberships - WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 - ) AS count - FROM room_memberships AS m - WHERE user_id = ? AND forgotten = 1 - GROUP BY room_id, user_id; - """ - txn.execute(sql, (user_id,)) - return {row[0] for row in txn if row[1] == 0} - - return self.db.runInteraction( - "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn - ) - - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): - """Get all rooms that the user has ever been in. - - Args: - user_id (str) - - Returns: - Deferred[set[str]]: Set of room IDs. - """ - - room_ids = yield self.db.simple_select_onecol( - table="room_memberships", - keyvalues={"membership": Membership.JOIN, "user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_has_been_in", - ) - - return set(room_ids) - - def get_membership_from_event_ids( - self, member_event_ids: Iterable[str] - ) -> List[dict]: - """Get user_id and membership of a set of event IDs. - """ - - return self.db.simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=("user_id", "membership", "event_id"), - keyvalues={}, - batch_size=500, - desc="get_membership_from_event_ids", - ) - - async def is_local_host_in_room_ignoring_users( - self, room_id: str, ignore_users: Collection[str] - ) -> bool: - """Check if there are any local users, excluding those in the given - list, in the room. - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "user_id", ignore_users - ) - - sql = """ - SELECT 1 FROM local_current_membership - WHERE - room_id = ? AND membership = ? 
- AND NOT (%s) - LIMIT 1 - """ % ( - clause, - ) - - def _is_local_host_in_room_ignoring_users_txn(txn): - txn.execute(sql, (room_id, Membership.JOIN, *args)) - - return bool(txn.fetchone()) - - return await self.db.runInteraction( - "is_local_host_in_room_ignoring_users", - _is_local_host_in_room_ignoring_users_txn, - ) - - -class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - self.db.updates.register_background_update_handler( - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - self._background_current_state_membership, - ) - self.db.updates.register_background_index_update( - "room_membership_forgotten_idx", - index_name="room_memberships_user_room_forgotten", - table="room_memberships", - columns=["user_id", "room_id"], - where_clause="forgotten = 1", - ) - - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): - target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start - ) - max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 - ) - - INSERT_CLUMP_SIZE = 1000 - - def add_membership_profile_txn(txn): - sql = """ - SELECT stream_ordering, event_id, events.room_id, event_json.json - FROM events - INNER JOIN event_json USING (event_id) - INNER JOIN room_memberships USING (event_id) - WHERE ? <= stream_ordering AND stream_ordering < ? - AND type = 'm.room.member' - ORDER BY stream_ordering DESC - LIMIT ? - """ - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = self.db.cursor_to_dict(txn) - if not rows: - return 0 - - min_stream_id = rows[-1]["stream_ordering"] - - to_update = [] - for row in rows: - event_id = row["event_id"] - room_id = row["room_id"] - try: - event_json = db_to_json(row["json"]) - content = event_json["content"] - except Exception: - continue - - display_name = content.get("displayname", None) - avatar_url = content.get("avatar_url", None) - - if display_name or avatar_url: - to_update.append((display_name, avatar_url, event_id, room_id)) - - to_update_sql = """ - UPDATE room_memberships SET display_name = ?, avatar_url = ? - WHERE event_id = ? AND room_id = ? - """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - } - - self.db.updates._background_update_progress_txn( - txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.db.runInteraction( - _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn - ) - - if not result: - yield self.db.updates._end_background_update( - _MEMBERSHIP_PROFILE_UPDATE_NAME - ) - - return result - - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): - """Update the new membership column on current_state_events. - - This works by iterating over all rooms in alphebetical order. - """ - - def _background_current_state_membership_txn(txn, last_processed_room): - processed = 0 - while processed < batch_size: - txn.execute( - """ - SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? 
- """, - (last_processed_room,), - ) - row = txn.fetchone() - if not row or not row[0]: - return processed, True - - (next_room,) = row - - sql = """ - UPDATE current_state_events - SET membership = ( - SELECT membership FROM room_memberships - WHERE event_id = current_state_events.event_id - ) - WHERE room_id = ? - """ - txn.execute(sql, (next_room,)) - processed += txn.rowcount - - last_processed_room = next_room - - self.db.updates._background_update_progress_txn( - txn, - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - {"last_processed_room": last_processed_room}, - ) - - return processed, False - - # If we haven't got a last processed room then just use the empty - # string, which will compare before all room IDs correctly. - last_processed_room = progress.get("last_processed_room", "") - - row_count, finished = yield self.db.runInteraction( - "_background_current_state_membership_update", - _background_current_state_membership_txn, - last_processed_room, - ) - - if finished: - yield self.db.updates._end_background_update( - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME - ) - - return row_count - - -class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): - super(RoomMemberStore, self).__init__(database, db_conn, hs) - - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.get_forgotten_rooms_for_user, (user_id,) - ) - - return self.db.runInteraction("forget_membership", f) - - -class _JoinedHostsCache(object): - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. 
- """ - - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id - - self.hosts_to_joined_users = {} - - self.state_group = object() - - self.linearizer = Linearizer("_JoinedHostsCache") - - self._len = 0 - - @defer.inlineCallbacks - def get_destinations(self, state_entry): - """Get set of destinations for a state entry - - Args: - state_entry(synapse.state._StateCacheEntry) - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (yield self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in state_entry.delta_ids.items(): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = yield self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = yield self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) - return frozenset(self.hosts_to_joined_users) - - def __len__(self): - return self._len diff --git a/synapse/storage/data_stores/main/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql deleted file mode 100644 index 5964c5aaac..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/12/v12.sql +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE IF NOT EXISTS rejections( - event_id TEXT NOT NULL, - reason TEXT NOT NULL, - last_check TEXT NOT NULL, - UNIQUE (event_id) -); - --- Push notification endpoints that users have configured -CREATE TABLE IF NOT EXISTS pushers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey VARBINARY(512) NOT NULL, - ts BIGINT UNSIGNED NOT NULL, - lang VARCHAR(8), - data LONGBLOB, - last_token TEXT, - last_success BIGINT UNSIGNED, - failing_since BIGINT UNSIGNED, - UNIQUE (app_id, pushkey) -); - -CREATE TABLE IF NOT EXISTS push_rules ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - priority_class TINYINT NOT NULL, - priority INTEGER NOT NULL DEFAULT 0, - conditions TEXT NOT NULL, - actions TEXT NOT NULL, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); - -CREATE TABLE IF NOT EXISTS user_filters( - user_id TEXT, - filter_id BIGINT UNSIGNED, - filter_json LONGBLOB -); - -CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( - user_id, filter_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql deleted file mode 100644 index f8649e5d99..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/13/v13.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create a tables called application_services and - * application_services_regex, but these are no longer used and are removed in - * delta 54. - */ diff --git a/synapse/storage/data_stores/main/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql deleted file mode 100644 index a831920da6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/14/v14.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -CREATE TABLE IF NOT EXISTS push_rules_enable ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - enabled TINYINT, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql deleted file mode 100644 index e4f5e76aec..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS application_services_state( - as_id TEXT PRIMARY KEY, - state VARCHAR(5), - last_txn INTEGER -); - -CREATE TABLE IF NOT EXISTS application_services_txns( - as_id TEXT NOT NULL, - txn_id INTEGER NOT NULL, - event_ids TEXT NOT NULL, - UNIQUE(as_id, txn_id) -); - -CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns ( - as_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql deleted file mode 100644 index 6b8d0f1ca7..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql +++ /dev/null @@ -1,2 +0,0 @@ - -CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql deleted file mode 100644 index 9523d2bcc3..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/15/v15.sql +++ /dev/null @@ -1,24 +0,0 @@ --- Drop, copy & recreate pushers table to change unique key --- Also add access_token column at the same time -CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey bytea NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data bytea, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey) -); -INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) - SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; -DROP TABLE pushers; -ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql deleted file mode 100644 index a48f215170..0000000000 --- 
a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE INDEX events_order ON events (topological_ordering, stream_ordering); -CREATE INDEX events_order_room ON events ( - room_id, topological_ordering, stream_ordering -); diff --git a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql deleted file mode 100644 index 7a15265cb1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql +++ /dev/null @@ -1,2 +0,0 @@ -CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id - ON remote_media_cache_thumbnails (media_id); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql deleted file mode 100644 index 65c97b5e2f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql +++ /dev/null @@ -1,9 +0,0 @@ - - -DELETE FROM event_to_state_groups WHERE state_group not in ( - SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id -); - -DELETE FROM event_to_state_groups WHERE rowid not in ( - SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql deleted file mode 100644 index f82486132b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql +++ /dev/null @@ -1,3 +0,0 @@ - -CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); -CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql deleted file mode 100644 index 5b8de52c33..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql +++ /dev/null @@ -1,72 +0,0 @@ - --- We can use SQLite features here, since other db support was only added in v16 - --- -DELETE FROM current_state_events WHERE rowid not in ( - SELECT MIN(rowid) FROM current_state_events GROUP BY event_id -); - -DROP INDEX IF EXISTS current_state_events_event_id; -CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id); - --- -DELETE FROM room_memberships WHERE rowid not in ( - SELECT MIN(rowid) FROM room_memberships GROUP BY event_id -); - -DROP INDEX IF EXISTS room_memberships_event_id; -CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id); - --- -DELETE FROM topics WHERE rowid not in ( - SELECT MIN(rowid) FROM topics GROUP BY event_id -); - -DROP INDEX IF EXISTS topics_event_id; -CREATE UNIQUE INDEX topics_event_id ON topics(event_id); - --- -DELETE FROM room_names WHERE rowid not in ( - SELECT MIN(rowid) FROM room_names GROUP BY event_id -); - -DROP INDEX IF EXISTS room_names_id; -CREATE UNIQUE INDEX room_names_id ON room_names(event_id); - --- -DELETE FROM presence WHERE rowid not in ( - SELECT MIN(rowid) FROM presence GROUP BY user_id -); - -DROP INDEX IF EXISTS presence_id; -CREATE UNIQUE INDEX presence_id ON presence(user_id); - --- -DELETE FROM presence_allow_inbound WHERE rowid not in ( - SELECT MIN(rowid) FROM presence_allow_inbound - GROUP BY observed_user_id, observer_user_id -); - -DROP INDEX IF EXISTS 
presence_allow_inbound_observers; -CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound( - observed_user_id, observer_user_id -); - --- -DELETE FROM presence_list WHERE rowid not in ( - SELECT MIN(rowid) FROM presence_list - GROUP BY user_id, observed_user_id -); - -DROP INDEX IF EXISTS presence_list_observers; -CREATE UNIQUE INDEX presence_list_observers ON presence_list( - user_id, observed_user_id -); - --- -DELETE FROM room_aliases WHERE rowid not in ( - SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias -); - -DROP INDEX IF EXISTS room_aliases_id; -CREATE INDEX room_aliases_id ON room_aliases(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql deleted file mode 100644 index cd0709250d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/16/users.sql +++ /dev/null @@ -1,56 +0,0 @@ --- Convert `access_tokens`.user from rowids to user strings. --- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW -CREATE TABLE IF NOT EXISTS new_access_tokens( - id BIGINT UNSIGNED PRIMARY KEY, - user_id TEXT NOT NULL, - device_id TEXT, - token TEXT NOT NULL, - last_used BIGINT UNSIGNED, - UNIQUE(token) -); - -INSERT INTO new_access_tokens - SELECT a.id, u.name, a.device_id, a.token, a.last_used - FROM access_tokens as a - INNER JOIN users as u ON u.id = a.user_id; - -DROP TABLE access_tokens; - -ALTER TABLE new_access_tokens RENAME TO access_tokens; - --- Remove ID column from `users` table -CREATE TABLE IF NOT EXISTS new_users( - name TEXT, - password_hash TEXT, - creation_ts BIGINT UNSIGNED, - admin BOOL DEFAULT 0 NOT NULL, - UNIQUE(name) -); - -INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users; - -DROP TABLE users; - -ALTER TABLE new_users RENAME TO users; - - --- Remove UNIQUE constraint from `user_ips` table -CREATE TABLE IF NOT EXISTS new_user_ips ( - user_id TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT, - ip TEXT NOT NULL, - user_agent TEXT NOT NULL, - last_seen BIGINT UNSIGNED NOT NULL -); - -INSERT INTO new_user_ips - SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips; - -DROP TABLE user_ips; - -ALTER TABLE new_user_ips RENAME TO user_ips; - -CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id); -CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip); - diff --git a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql deleted file mode 100644 index 7c9a90e27f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -DROP INDEX IF EXISTS sent_transaction_dest; -DROP INDEX IF EXISTS sent_transaction_sent; -DROP INDEX IF EXISTS user_ips_user; diff --git a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql deleted file mode 100644 index 70b247a06b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS server_keys_json ( - server_name TEXT, -- Server name. - key_id TEXT, -- Requested key id. - from_server TEXT, -- Which server the keys were fetched from. - ts_added_ms INTEGER, -- When the keys were fetched - ts_valid_until_ms INTEGER, -- When this version of the keys exipires. - key_json bytea, -- JSON certificate for the remote server. - CONSTRAINT uniqueness UNIQUE (server_name, key_id, from_server) -); diff --git a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql deleted file mode 100644 index c17715ac80..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql +++ /dev/null @@ -1,9 +0,0 @@ -CREATE TABLE user_threepids ( - user_id TEXT NOT NULL, - medium TEXT NOT NULL, - address TEXT NOT NULL, - validated_at BIGINT NOT NULL, - added_at BIGINT NOT NULL, - CONSTRAINT user_medium_address UNIQUE (user_id, medium, address) -); -CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql deleted file mode 100644 index 6e0871c92b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS new_server_keys_json ( - server_name TEXT NOT NULL, -- Server name. - key_id TEXT NOT NULL, -- Requested key id. - from_server TEXT NOT NULL, -- Which server the keys were fetched from. - ts_added_ms BIGINT NOT NULL, -- When the keys were fetched - ts_valid_until_ms BIGINT NOT NULL, -- When this version of the keys exipires. - key_json bytea NOT NULL, -- JSON certificate for the remote server. 
- CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) -); - -INSERT INTO new_server_keys_json - SELECT server_name, key_id, from_server,ts_added_ms, ts_valid_until_ms, key_json FROM server_keys_json ; - -DROP TABLE server_keys_json; - -ALTER TABLE new_server_keys_json RENAME TO server_keys_json; diff --git a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql deleted file mode 100644 index 18b97b4332..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE INDEX events_order_topo_stream_room ON events( - topological_ordering, stream_ordering, room_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql deleted file mode 100644 index e0ac49d1ec..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql +++ /dev/null @@ -1 +0,0 @@ -SELECT 1; diff --git a/synapse/storage/data_stores/main/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py deleted file mode 100644 index 3edfcfd783..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/20/pushers.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -Main purpose of this upgrade is to change the unique key on the -pushers table again (it was missed when the v16 full schema was -made) but this also changes the pushkey and data columns to text. -When selecting a bytea column into a text column, postgres inserts -the hex encoded data, and there's no portable way of getting the -UTF-8 bytes, so we have to do it in Python. 
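
The docstring above states the key constraint behind this migration: selecting a `bytea` column into a `text` column leaves you with the hex-encoded form of the value rather than its UTF-8 string, so the conversion has to happen in Python. A tiny self-contained illustration of the difference follows; plain Python bytes stand in for the database round trip, and the pushkey value is made up.

# Illustration of the bytea -> text issue described above.
pushkey_bytes = "https://push.example.com/_matrix/push/v1/notify".encode("utf-8")

# Casting the bytea column to text in SQL would typically give the hex form,
# e.g. "\x68747470...", which is not the string we want to keep.
hex_form = "\\x" + pushkey_bytes.hex()

# Decoding in Python instead recovers the original UTF-8 string, which is what
# the migration below does with bytes(row[i]).decode("utf-8").
pushkey_text = bytes(pushkey_bytes).decode("utf-8")

assert pushkey_text == "https://push.example.com/_matrix/push/v1/notify"
assert hex_form.startswith("\\x68747470")  # "http" -> 68 74 74 70
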
-""" - -import logging - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - logger.info("Porting pushers table...") - cur.execute( - """ - CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey TEXT NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data TEXT, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey, user_name) - ) - """ - ) - cur.execute( - """SELECT - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - FROM pushers - """ - ) - count = 0 - for row in cur.fetchall(): - row = list(row) - row[8] = bytes(row[8]).decode("utf-8") - row[11] = bytes(row[11]).decode("utf-8") - cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), - row, - ) - count += 1 - cur.execute("DROP TABLE pushers") - cur.execute("ALTER TABLE pushers2 RENAME TO pushers") - logger.info("Moved %d pushers to new table", count) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql deleted file mode 100644 index 4c2fb20b77..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( - user_id TEXT NOT NULL, -- The user these keys are for. - device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. - ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. - key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. - CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) -); - - -CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( - user_id TEXT NOT NULL, -- The user this one-time key is for. - device_id TEXT NOT NULL, -- The device this one-time key is for. - algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. - key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. - ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. - key_json TEXT NOT NULL, -- The key as a JSON blob. 
- CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql deleted file mode 100644 index d070845477..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS receipts_graph( - room_id TEXT NOT NULL, - receipt_type TEXT NOT NULL, - user_id TEXT NOT NULL, - event_ids TEXT NOT NULL, - data TEXT NOT NULL, - CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) -); - -CREATE TABLE IF NOT EXISTS receipts_linearized ( - stream_id BIGINT NOT NULL, - room_id TEXT NOT NULL, - receipt_type TEXT NOT NULL, - user_id TEXT NOT NULL, - event_id TEXT NOT NULL, - data TEXT NOT NULL, - CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) -); - -CREATE INDEX receipts_linearized_id ON receipts_linearized( - stream_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql deleted file mode 100644 index bfc0b3bcaa..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( - room_id, stream_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql deleted file mode 100644 index 87edfa454c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql +++ /dev/null @@ -1,19 +0,0 @@ -CREATE TABLE IF NOT EXISTS user_threepids2 ( - user_id TEXT NOT NULL, - medium TEXT NOT NULL, - address TEXT NOT NULL, - validated_at BIGINT NOT NULL, - added_at BIGINT NOT NULL, - CONSTRAINT medium_address UNIQUE (medium, address) -); - -INSERT INTO user_threepids2 - SELECT * FROM user_threepids WHERE added_at IN ( - SELECT max(added_at) FROM user_threepids GROUP BY medium, address - ) -; - -DROP TABLE user_threepids; -ALTER TABLE user_threepids2 RENAME TO user_threepids; - -CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql deleted file mode 100644 index acea7483bd..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - /* We used to create a table called stats_reporting, but this is no longer - * used and is removed in delta 54. - */ \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py deleted file mode 100644 index ee675e71ff..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/25/fts.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import json -import logging - -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -POSTGRES_TABLE = """ -CREATE TABLE IF NOT EXISTS event_search ( - event_id TEXT, - room_id TEXT, - sender TEXT, - key TEXT, - vector tsvector -); - -CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); -CREATE INDEX event_search_ev_idx ON event_search(event_id); -CREATE INDEX event_search_ev_ridx ON event_search(room_id); -""" - - -SQLITE_TABLE = ( - "CREATE VIRTUAL TABLE event_search" - " USING fts4 ( event_id, room_id, sender, key, value )" -) - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - for statement in get_statements(POSTGRES_TABLE.splitlines()): - cur.execute(statement) - elif isinstance(database_engine, Sqlite3Engine): - cur.execute(SQLITE_TABLE) - else: - raise Exception("Unrecognized database engine") - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - } - progress_json = json.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_search", progress_json)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql deleted file mode 100644 index 1ea389b471..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This is a manual index of guest_access content of state events, - * so that we can join on them in SELECT statements. - */ -CREATE TABLE IF NOT EXISTS guest_access( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - guest_access TEXT NOT NULL, - UNIQUE (event_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql deleted file mode 100644 index f468fc1897..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This is a manual index of history_visibility content of state events, - * so that we can join on them in SELECT statements. - */ -CREATE TABLE IF NOT EXISTS history_visibility( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - history_visibility TEXT NOT NULL, - UNIQUE (event_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql deleted file mode 100644 index 7a32ce68e4..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/25/tags.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS room_tags( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - tag TEXT NOT NULL, -- The name of the tag. - content TEXT NOT NULL, -- The JSON content of the tag. - CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) -); - -CREATE TABLE IF NOT EXISTS room_tags_revisions ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - stream_id BIGINT NOT NULL, -- The current version of the room tags. - CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) -); - -CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT NOT NULL, - CHECK (Lock='X') -); - -INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql deleted file mode 100644 index e395de2b5e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; diff --git a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql deleted file mode 100644 index bf0558b5b3..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS account_data( - user_id TEXT NOT NULL, - account_data_type TEXT NOT NULL, -- The type of the account_data. - stream_id BIGINT NOT NULL, -- The version of the account_data. - content TEXT NOT NULL, -- The JSON content of the account_data - CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) -); - - -CREATE TABLE IF NOT EXISTS room_account_data( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - account_data_type TEXT NOT NULL, -- The type of the account_data. - stream_id BIGINT NOT NULL, -- The version of the account_data. - content TEXT NOT NULL, -- The JSON content of the account_data - CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) -); - - -CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); -CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql deleted file mode 100644 index e2094f37fe..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Keeps track of what rooms users have left and don't want to be able to - * access again. - * - * If all users on this server have left a room, we can delete the room - * entirely. - * - * This column should always contain either 0 or 1. 
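
The comment above spells out the intended semantics of the new column: once every local user has forgotten a room, the room can be removed entirely. Below is a small sketch of how that "all memberships forgotten" condition could be expressed, using an in-memory SQLite table trimmed to the relevant columns; it is not the query Synapse itself runs (see `get_forgotten_rooms_for_user` in the room member store code earlier in this diff for the real shape).

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE room_memberships (room_id TEXT, user_id TEXT, forgotten INTEGER DEFAULT 0)"
)
conn.executemany(
    "INSERT INTO room_memberships VALUES (?, ?, ?)",
    [
        ("!old:example.com", "@alice:example.com", 1),
        ("!old:example.com", "@bob:example.com", 1),
        ("!active:example.com", "@alice:example.com", 1),
        ("!active:example.com", "@bob:example.com", 0),
    ],
)

# A room is fully forgotten only if every membership row for it is forgotten,
# i.e. the minimum of the 0/1 flag across the room is 1.
fully_forgotten = [
    row[0]
    for row in conn.execute(
        """
        SELECT room_id FROM room_memberships
        GROUP BY room_id
        HAVING MIN(forgotten) = 1
        """
    )
]

assert fully_forgotten == ["!old:example.com"]
conn.close()
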
- */ - - ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER DEFAULT 0; diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py deleted file mode 100644 index b7972cfa8e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/27/ts.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import logging - -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -ALTER_TABLE = ( - "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;" - "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);" -) - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(ALTER_TABLE.splitlines()): - cur.execute(statement) - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - } - progress_json = json.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_origin_server_ts", progress_json)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql deleted file mode 100644 index 4d519849df..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE IF NOT EXISTS event_push_actions( - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - profile_tag VARCHAR(32), - actions TEXT NOT NULL, - CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) -); - - -CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag); -CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql deleted file mode 100644 index 36609475f1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql deleted file mode 100644 index 6c1fd68c5b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql deleted file mode 100644 index cb84c69baa..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX receipts_linearized_user ON receipts_linearized( - user_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql deleted file mode 100644 index 3e4a9ab455..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Stores the timestamp when a user upgraded from a guest to a full user, if - * that happened. - */ - -ALTER TABLE users ADD COLUMN upgrade_ts BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql deleted file mode 100644 index 21d2b420bf..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE users ADD is_guest SMALLINT DEFAULT 0 NOT NULL; -/* - * NB: any guest users created between 27 and 28 will be incorrectly - * marked as not guests: we don't bother to fill these in correctly - * because guest access is not really complete in 27 anyway so it's - * very unlikley there will be any guest users created. 
- */ diff --git a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql deleted file mode 100644 index 84b21cf813..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT; -ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT; -ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT; -ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT; - -UPDATE event_push_actions SET stream_ordering = ( - SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id -), topological_ordering = ( - SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id -); - -UPDATE event_push_actions SET notif = 1, highlight = 0; - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX event_push_actions_rm_tokens on event_push_actions( - user_id, room_id, topological_ordering, stream_ordering -); diff --git a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql deleted file mode 100644 index c9d0dde638..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE room_aliases ADD COLUMN creator TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py deleted file mode 100644 index b42c02710a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from synapse.config.appservice import load_appservices - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - # NULL indicates user was not registered by an appservice. - try: - cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") - except Exception: - # Maybe we already added the column? Hope so... - pass - - -def run_upgrade(cur, database_engine, config, *args, **kwargs): - cur.execute("SELECT name FROM users") - rows = cur.fetchall() - - config_files = [] - try: - config_files = config.app_service_config_files - except AttributeError: - logger.warning("Could not get app_service_config_files from config") - pass - - appservices = load_appservices(config.server_name, config_files) - - owned = {} - - for row in rows: - user_id = row[0] - for appservice in appservices: - if appservice.is_exclusive_user(user_id): - if user_id in owned.keys(): - logger.error( - "user_id %s was owned by more than one application" - " service (IDs %s and %s); assigning arbitrarily to %s" - % (user_id, owned[user_id], appservice.id, owned[user_id]) - ) - owned.setdefault(appservice.id, []).append(user_id) - - for as_id, user_ids in owned.items(): - n = 100 - user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) - for chunk in user_chunks: - cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), - [as_id] + chunk, - ) diff --git a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql deleted file mode 100644 index 712c454aa1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS deleted_pushers( - stream_id BIGINT NOT NULL, - app_id TEXT NOT NULL, - pushkey TEXT NOT NULL, - user_id TEXT NOT NULL, - /* We only track the most recent delete for each app_id, pushkey and user_id. 
*/ - UNIQUE (app_id, pushkey, user_id) -); - -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql deleted file mode 100644 index 606bbb037d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - - CREATE TABLE presence_stream( - stream_id BIGINT, - user_id TEXT, - state TEXT, - last_active_ts BIGINT, - last_federation_update_ts BIGINT, - last_user_sync_ts BIGINT, - status_msg TEXT, - currently_active BOOLEAN - ); - - CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); - CREATE INDEX presence_stream_user_id ON presence_stream(user_id); - CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql deleted file mode 100644 index f09db4faa6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/* This release removes the restriction that published rooms must have an alias, - * so we go back and ensure the only 'public' rooms are ones with an alias. - * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite - */ -UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in ( - SELECT room_id FROM room_aliases -); diff --git a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql deleted file mode 100644 index 735aa8d5f6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - - - -CREATE TABLE push_rules_stream( - stream_id BIGINT NOT NULL, - event_stream_ordering BIGINT NOT NULL, - user_id TEXT NOT NULL, - rule_id TEXT NOT NULL, - op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" - priority_class SMALLINT, - priority INTEGER, - conditions TEXT, - actions TEXT -); - --- The extra data for each operation is: --- * ENABLE, DISABLE, DELETE: [] --- * ACTIONS: ["actions"] --- * ADD: ["priority_class", "priority", "actions", "conditions"] - --- Index for replication queries. -CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); --- Index for /sync queries. -CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql deleted file mode 100644 index 0dd2f1360c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Stores guest account access tokens generated for unbound 3pids. -CREATE TABLE threepid_guest_access_tokens( - medium TEXT, -- The medium of the 3pid. Must be "email". - address TEXT, -- The 3pid address. - guest_access_token TEXT, -- The access token for a guest user for this 3pid. - first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. -); - -CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/data_stores/main/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql deleted file mode 100644 index 2c57846d5a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/31/invites.sql +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -CREATE TABLE local_invites( - stream_id BIGINT NOT NULL, - inviter TEXT NOT NULL, - invitee TEXT NOT NULL, - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - locally_rejected TEXT, - replaced_by TEXT -); - --- Insert all invites for local users into new `invites` table -INSERT INTO local_invites SELECT - stream_ordering as stream_id, - sender as inviter, - state_key as invitee, - event_id, - room_id, - NULL as locally_rejected, - NULL as replaced_by - FROM events - NATURAL JOIN current_state_events - NATURAL JOIN room_memberships - WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); - -CREATE INDEX local_invites_id ON local_invites(stream_id); -CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql deleted file mode 100644 index 9efb4280eb..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE local_media_repository_url_cache( - url TEXT, -- the URL being cached - response_code INTEGER, -- the HTTP response code of this download attempt - etag TEXT, -- the etag header of this response - expires INTEGER, -- the number of ms this response was valid for - og TEXT, -- cache of the OG metadata of this URL as JSON - media_id TEXT, -- the media_id, if any, of the URL's content in the repo - download_ts BIGINT -- the timestamp of this download attempt -); - -CREATE INDEX local_media_repository_url_cache_by_url_download_ts - ON local_media_repository_url_cache(url, download_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py deleted file mode 100644 index 9bb504aad5..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/31/pushers.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Change the last_token to last_stream_ordering now that pushers no longer -# listen on an event stream but instead select out of the event_push_actions -# table. 
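As an illustration of the conversion this delta's header comment describes (not part of the original file): the old last_token values were event-stream tokens, and the parser in the delta below assumes they take the form "s<stream_ordering>_<rest>", so recovering the stream ordering is just a matter of pulling out the first field. A minimal standalone sketch under that assumption:

    def token_to_stream_ordering(token: str) -> int:
        # "s1234_56" -> 1234: drop the leading type character and take the
        # first "_"-separated field as the integer stream ordering.
        return int(token[1:].split("_")[0])

    # e.g. a pusher whose last_token was "s1234_56" resumes from ordering 1234
    assert token_to_stream_ordering("s1234_56") == 1234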
- - -import logging - -logger = logging.getLogger(__name__) - - -def token_to_stream_ordering(token): - return int(token[1:].split("_")[0]) - - -def run_create(cur, database_engine, *args, **kwargs): - logger.info("Porting pushers table, delta 31...") - cur.execute( - """ - CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey TEXT NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data TEXT, - last_stream_ordering INTEGER, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey, user_name) - ) - """ - ) - cur.execute( - """SELECT - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - FROM pushers - """ - ) - count = 0 - for row in cur.fetchall(): - row = list(row) - row[12] = token_to_stream_ordering(row[12]) - cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), - row, - ) - count += 1 - cur.execute("DROP TABLE pushers") - cur.execute("ALTER TABLE pushers2 RENAME TO pushers") - logger.info("Moved %d pushers to new table", count) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql deleted file mode 100644 index a82add88fd..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ - CREATE INDEX event_push_actions_stream_ordering on event_push_actions( - stream_ordering, user_id - ); diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py deleted file mode 100644 index 63b757ade6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import logging - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -ALTER_TABLE = """ -ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT; -ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT; -""" - - -def run_create(cur, database_engine, *args, **kwargs): - if not isinstance(database_engine, PostgresEngine): - return - - for statement in get_statements(ALTER_TABLE.splitlines()): - cur.execute(statement) - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - "have_added_indexes": False, - } - progress_json = json.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_search_order", progress_json)) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql deleted file mode 100644 index 1dd0f9e170..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/32/events.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -ALTER TABLE events ADD COLUMN received_ts BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql deleted file mode 100644 index 36f37b11c8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/32/openid.sql +++ /dev/null @@ -1,9 +0,0 @@ - -CREATE TABLE open_id_tokens ( - token TEXT NOT NULL PRIMARY KEY, - ts_valid_until_ms bigint NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); - -CREATE index open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); diff --git a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql deleted file mode 100644 index d86d30c13c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE pusher_throttle( - pusher BIGINT NOT NULL, - room_id TEXT NOT NULL, - last_sent_ts BIGINT, - throttle_ms BIGINT, - PRIMARY KEY (pusher, room_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql deleted file mode 100644 index 2de50d408c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - --- The following indices are redundant, other indices are equivalent or --- supersets -DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream -DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room -DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room -DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT - -DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT - --- The following indices were unused -DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id; -DROP INDEX IF EXISTS evauth_edges_auth_id; -DROP INDEX IF EXISTS presence_stream_state; diff --git a/synapse/storage/data_stores/main/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql deleted file mode 100644 index d13609776f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/32/reports.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE event_reports( - id BIGINT NOT NULL PRIMARY KEY, - received_ts BIGINT NOT NULL, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - reason TEXT, - content TEXT -); diff --git a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql deleted file mode 100644 index 61ad3fe3e8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql deleted file mode 100644 index eca7268d82..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/devices.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE devices ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - display_name TEXT, - CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql deleted file mode 100644 index aa4a3b9f2f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- make sure that we have a device record for each set of E2E keys, so that the --- user can delete them if they like. -INSERT INTO devices - SELECT user_id, device_id, NULL FROM e2e_device_keys_json; diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql deleted file mode 100644 index 6671573398..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- a previous version of the "devices_for_e2e_keys" delta set all the device --- names to "unknown device". 
This wasn't terribly helpful -UPDATE devices - SET display_name = NULL - WHERE display_name = 'unknown device'; diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py deleted file mode 100644 index a3e81eeac7..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import logging - -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -ALTER_TABLE = """ -ALTER TABLE events ADD COLUMN sender TEXT; -ALTER TABLE events ADD COLUMN contains_url BOOLEAN; -""" - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(ALTER_TABLE.splitlines()): - cur.execute(statement) - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - } - progress_json = json.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_fields_sender_url", progress_json)) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py deleted file mode 100644 index a26057dfb6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" - - -def run_create(cur, database_engine, *args, **kwargs): - cur.execute(ALTER_TABLE) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" 
- ), - (int(time.time() * 1000),), - ) diff --git a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql deleted file mode 100644 index 473f75a78e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('user_ips_device_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql deleted file mode 100644 index 69e16eda0f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS appservice_stream_position( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_ordering BIGINT, - CHECK (Lock='X') -); - -INSERT INTO appservice_stream_position (stream_ordering) - SELECT COALESCE(MAX(stream_ordering), 0) FROM events; diff --git a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py deleted file mode 100644 index cf09e43e2b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -# This stream is used to notify replication slaves that some caches have -# been invalidated that they cannot infer from the other streams. 
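The table this delta creates below only exists on PostgreSQL (note the TEXT[] column and the PostgresEngine check), and, per the comment above, replication slaves consume its rows to learn about cache invalidations they cannot infer from other streams. A rough sketch of that consumption, assuming a psycopg2-style cursor and a hypothetical in-memory caches mapping, and polling the table directly purely for illustration:

    def drain_cache_invalidations(cur, caches, last_seen_id):
        # Fetch invalidations newer than our last-seen position, oldest first.
        cur.execute(
            "SELECT stream_id, cache_func, keys FROM cache_invalidation_stream"
            " WHERE stream_id > %s ORDER BY stream_id ASC",
            (last_seen_id,),
        )
        for stream_id, cache_func, keys in cur.fetchall():
            cache = caches.get(cache_func)
            if cache is not None:
                # Drop the locally cached entry for these keys, if present.
                cache.pop(tuple(keys), None)
            last_seen_id = stream_id
        return last_seen_id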
-CREATE_TABLE = """ -CREATE TABLE cache_invalidation_stream ( - stream_id BIGINT, - cache_func TEXT, - keys TEXT[], - invalidation_ts BIGINT -); - -CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - if not isinstance(database_engine, PostgresEngine): - return - - for statement in get_statements(CREATE_TABLE.splitlines()): - cur.execute(statement) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql deleted file mode 100644 index e68844c74a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE device_inbox ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - stream_id BIGINT NOT NULL, - message_json TEXT NOT NULL -- {"type":, "sender":, "content",} -); - -CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); -CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql deleted file mode 100644 index 0d9fe1a99a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name'; -UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; - -DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name'; -UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; diff --git a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py deleted file mode 100644 index 67d505e68b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - cur.execute("TRUNCATE received_transactions") - else: - cur.execute("DELETE FROM received_transactions") - - cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql deleted file mode 100644 index 6cd123027b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - INSERT into background_updates (update_name, progress_json) - VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql deleted file mode 100644 index 17e6c43105..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP TABLE IF EXISTS device_federation_outbox; -CREATE TABLE device_federation_outbox ( - destination TEXT NOT NULL, - stream_id BIGINT NOT NULL, - queued_ts BIGINT NOT NULL, - messages_json TEXT NOT NULL -); - - -DROP INDEX IF EXISTS device_federation_outbox_destination_id; -CREATE INDEX device_federation_outbox_destination_id - ON device_federation_outbox(destination, stream_id); - - -DROP TABLE IF EXISTS device_federation_inbox; -CREATE TABLE device_federation_inbox ( - origin TEXT NOT NULL, - message_id TEXT NOT NULL, - received_ts BIGINT NOT NULL -); - -DROP INDEX IF EXISTS device_federation_inbox_sender_id; -CREATE INDEX device_federation_inbox_sender_id - ON device_federation_inbox(origin, message_id); diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql deleted file mode 100644 index 7ab7d942e2..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE device_max_stream_id ( - stream_id BIGINT NOT NULL -); - -INSERT INTO device_max_stream_id (stream_id) - SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox; diff --git a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql deleted file mode 100644 index 2e836d8e9c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - INSERT into background_updates (update_name, progress_json) - VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql deleted file mode 100644 index dd2bf2e28a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE public_room_list_stream ( - stream_id BIGINT NOT NULL, - room_id TEXT NOT NULL, - visibility BOOLEAN NOT NULL -); - -INSERT INTO public_room_list_stream (stream_id, room_id, visibility) - SELECT 1, room_id, is_public FROM rooms - WHERE is_public = CAST(1 AS BOOLEAN); - -CREATE INDEX public_room_list_stream_idx on public_room_list_stream( - stream_id -); - -CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( - room_id, stream_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql deleted file mode 100644 index 2b945d8a57..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -CREATE TABLE stream_ordering_to_exterm ( - stream_ordering BIGINT NOT NULL, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL -); - -INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) - SELECT stream_ordering, room_id, event_id FROM event_forward_extremities - INNER JOIN ( - SELECT room_id, max(stream_ordering) as stream_ordering FROM events - INNER JOIN event_forward_extremities USING (room_id, event_id) - GROUP BY room_id - ) AS rms USING (room_id); - -CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( - stream_ordering -); - -CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( - room_id, stream_ordering -); diff --git a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql deleted file mode 100644 index 90d8fd18f9..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted -INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) - SELECT - (SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering, - room_id, - event_id - FROM event_forward_extremities AS e - WHERE NOT EXISTS ( - SELECT room_id FROM stream_ordering_to_exterm AS s - WHERE s.room_id = e.room_id - ); diff --git a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py deleted file mode 100644 index a377884169..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - -DROP_INDICES = """ --- We only ever query based on event_id -DROP INDEX IF EXISTS state_events_room_id; -DROP INDEX IF EXISTS state_events_type; -DROP INDEX IF EXISTS state_events_state_key; - --- room_id is indexed elsewhere -DROP INDEX IF EXISTS current_state_events_room_id; -DROP INDEX IF EXISTS current_state_events_state_key; -DROP INDEX IF EXISTS current_state_events_type; - -DROP INDEX IF EXISTS transactions_have_ref; - --- (topological_ordering, stream_ordering, room_id) seems like a strange index, --- and is used incredibly rarely. -DROP INDEX IF EXISTS events_order_topo_stream_room; - --- an equivalent index to this actually gets re-created in delta 41, because it --- turned out that deleting it wasn't a great plan :/. In any case, let's --- delete it here, and delta 41 will create a new one with an added UNIQUE --- constraint -DROP INDEX IF EXISTS event_search_ev_idx; -""" - -POSTGRES_DROP_CONSTRAINT = """ -ALTER TABLE event_auth DROP CONSTRAINT IF EXISTS event_auth_event_id_auth_id_room_id_key; -""" - -SQLITE_DROP_CONSTRAINT = """ -DROP INDEX IF EXISTS evauth_edges_id; - -CREATE TABLE IF NOT EXISTS event_auth_new( - event_id TEXT NOT NULL, - auth_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -INSERT INTO event_auth_new - SELECT event_id, auth_id, room_id - FROM event_auth; - -DROP TABLE event_auth; - -ALTER TABLE event_auth_new RENAME TO event_auth; - -CREATE INDEX evauth_edges_id ON event_auth(event_id); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(DROP_INDICES.splitlines()): - cur.execute(statement) - - if isinstance(database_engine, PostgresEngine): - drop_constraint = POSTGRES_DROP_CONSTRAINT - else: - drop_constraint = SQLITE_DROP_CONSTRAINT - - for statement in get_statements(drop_constraint.splitlines()): - cur.execute(statement) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql deleted file mode 100644 index cf7a90dd10..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Update any email addresses that were stored with mixed case into all - * lowercase - */ - - -- There may be "duplicate" emails (with different case) already in the table, - -- so we find them and move all but the most recently used account. - UPDATE user_threepids - SET medium = 'email_old' - WHERE medium = 'email' - AND address IN ( - -- We select all the addresses that are linked to the user_id that is NOT - -- the most recently created. 
- SELECT u.address - FROM - user_threepids AS u, - -- `duplicate_addresses` is a table of all the email addresses that - -- appear multiple times and when the binding was created - ( - SELECT lower(u1.address) AS address, max(u1.added_at) AS max_ts - FROM user_threepids AS u1 - INNER JOIN user_threepids AS u2 ON u1.medium = u2.medium AND lower(u1.address) = lower(u2.address) AND u1.address != u2.address - WHERE u1.medium = 'email' AND u2.medium = 'email' - GROUP BY lower(u1.address) - ) AS duplicate_addresses - WHERE - lower(u.address) = duplicate_addresses.address - AND u.added_at != max_ts -- NOT the most recently created - ); - - --- This update is now safe since we've removed the duplicate addresses. -UPDATE user_threepids SET address = LOWER(address) WHERE medium = 'email'; - - -/* Add an index for the select we do on passwored reset */ -CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); diff --git a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql deleted file mode 100644 index 515e6b8e84..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We no longer do this given we back it out again in schema 47 - --- INSERT into background_updates (update_name, progress_json) --- VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql deleted file mode 100644 index 74bdc49073..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE appservice_room_list( - appservice_id TEXT NOT NULL, - network_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - --- Each appservice can have multiple published room lists associated with them, --- keyed of a particular network_id -CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( - appservice_id, network_id, room_id -); - -ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT; -ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql deleted file mode 100644 index 00be801e90..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql deleted file mode 100644 index de2ad93e5c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('event_push_actions_highlights_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql deleted file mode 100644 index 5af814290b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - CREATE TABLE federation_stream_position( - type TEXT NOT NULL, - stream_id INTEGER NOT NULL - ); - - INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); - INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; diff --git a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql deleted file mode 100644 index 1bf911c8ab..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE room_memberships ADD COLUMN display_name TEXT; -ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT; - -INSERT into background_updates (update_name, progress_json) - VALUES ('room_membership_profile_update', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql deleted file mode 100644 index 7ffa189f39..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('current_state_members_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql deleted file mode 100644 index b9fe1f0480..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- turn the pre-fill startup query into a index-only scan on postgresql. 
-INSERT into background_updates (update_name, progress_json) - VALUES ('device_inbox_stream_index', '{}'); - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql deleted file mode 100644 index dd6dcb65f1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Cache of remote devices. -CREATE TABLE device_lists_remote_cache ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - content TEXT NOT NULL -); - --- The last update we got for a user. Empty if we're not receiving updates for --- that user. -CREATE TABLE device_lists_remote_extremeties ( - user_id TEXT NOT NULL, - stream_id TEXT NOT NULL -); - --- we used to create non-unique indexes on these tables, but as of update 52 we create --- unique indexes concurrently: --- --- CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); --- CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); - - --- Stream of device lists updates. Includes both local and remotes -CREATE TABLE device_lists_stream ( - stream_id BIGINT NOT NULL, - user_id TEXT NOT NULL, - device_id TEXT NOT NULL -); - -CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); - - --- The stream of updates to send to other servers. We keep at least one row --- per user that was sent so that the prev_id for any new updates can be --- calculated -CREATE TABLE device_lists_outbound_pokes ( - destination TEXT NOT NULL, - stream_id BIGINT NOT NULL, - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - sent BOOLEAN NOT NULL, - ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers -); - -CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); -CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql deleted file mode 100644 index 3918f0b794..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2017 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Aggregate of old notification counts that have been deleted out of the --- main event_push_actions table. This count does not include those that were --- highlights, as they remain in the event_push_actions table. -CREATE TABLE event_push_summary ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - notif_count BIGINT NOT NULL, - stream_ordering BIGINT NOT NULL -); - -CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); - - --- The stream ordering up to which we have aggregated the event_push_actions --- table into event_push_summary -CREATE TABLE event_push_summary_stream_ordering ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_ordering BIGINT NOT NULL, - CHECK (Lock='X') -); - -INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql deleted file mode 100644 index 054a223f14..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag TEXT NOT NULL, - kind TEXT NOT NULL, - app_id TEXT NOT NULL, - app_display_name TEXT NOT NULL, - device_display_name TEXT NOT NULL, - pushkey TEXT NOT NULL, - ts BIGINT NOT NULL, - lang TEXT, - data TEXT, - last_stream_ordering INTEGER, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey, user_name) -); - -INSERT INTO pushers2 SELECT * FROM PUSHERS; - -DROP TABLE PUSHERS; - -ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql deleted file mode 100644 index b7bee8b692..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql deleted file mode 100644 index 62f0b9892b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql deleted file mode 100644 index 5d9cfecf36..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql deleted file mode 100644 index a194bf0238..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE ratelimit_override ( - user_id TEXT NOT NULL, - messages_per_second BIGINT, - burst_count BIGINT -); - -CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql deleted file mode 100644 index d28851aff8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE current_state_delta_stream ( - stream_id BIGINT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_id TEXT, -- Is null if the key was removed - prev_event_id TEXT -- Is null if the key was added -); - -CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql deleted file mode 100644 index 9ab8c14fa3..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- Table of last stream_id that we sent to destination for user_id. This is --- used to fill out the `prev_id` fields of outbound device list updates. 
-CREATE TABLE device_lists_outbound_last_success ( - destination TEXT NOT NULL, - user_id TEXT NOT NULL, - stream_id BIGINT NOT NULL -); - -INSERT INTO device_lists_outbound_last_success - SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_pokes - WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values - GROUP BY destination, user_id; - -CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( - destination, user_id, stream_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql deleted file mode 100644 index b8821ac759..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('event_auth_state_only', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py deleted file mode 100644 index 506f326f4d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -BOTH_TABLES = """ -CREATE TABLE user_directory_stream_pos ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. 
- stream_id BIGINT, - CHECK (Lock='X') -); - -INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); - -CREATE TABLE user_directory ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, -- A room_id that we know the user is joined to - display_name TEXT, - avatar_url TEXT -); - -CREATE INDEX user_directory_room_idx ON user_directory(room_id); -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); - -CREATE TABLE users_in_pubic_room ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL -- A room_id that we know is public -); - -CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); -CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); -""" - - -POSTGRES_TABLE = """ -CREATE TABLE user_directory_search ( - user_id TEXT NOT NULL, - vector tsvector -); - -CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); -CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); -""" - - -SQLITE_TABLE = """ -CREATE VIRTUAL TABLE user_directory_search - USING fts4 ( user_id, value ); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(BOTH_TABLES.splitlines()): - cur.execute(statement) - - if isinstance(database_engine, PostgresEngine): - for statement in get_statements(POSTGRES_TABLE.splitlines()): - cur.execute(statement) - elif isinstance(database_engine, Sqlite3Engine): - for statement in get_statements(SQLITE_TABLE.splitlines()): - cur.execute(statement) - else: - raise Exception("Unrecognized database engine") - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql deleted file mode 100644 index 0e3cd143ff..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE blocked_rooms ( - room_id TEXT NOT NULL, - user_id TEXT NOT NULL -- Admin who blocked the room -); - -CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql deleted file mode 100644 index 630907ec4f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT; -ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql deleted file mode 100644 index 45ebe020da..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql deleted file mode 100644 index ee7062abe4..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Table keeping track of who shares a room with who. We only keep track --- of this for local users, so `user_id` is local users only (but we do keep track --- of which remote users share a room) -CREATE TABLE users_who_share_rooms ( - user_id TEXT NOT NULL, - other_user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - share_private BOOLEAN NOT NULL -- is the shared room private? i.e. 
they share a private room -); - - -CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id); -CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); -CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); - - --- Make sure that we populate the table initially -UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql deleted file mode 100644 index b12f9b2ebf..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was --- removed and replaced with 46/local_media_repository_url_idx.sql. --- --- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; - --- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support --- indices on expressions until 3.9. -CREATE TABLE local_media_repository_url_cache_new( - url TEXT, - response_code INTEGER, - etag TEXT, - expires_ts BIGINT, - og TEXT, - media_id TEXT, - download_ts BIGINT -); - -INSERT INTO local_media_repository_url_cache_new - SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; - -DROP TABLE local_media_repository_url_cache; -ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; - -CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); -CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); -CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql deleted file mode 100644 index b2333848a0..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE groups ( - group_id TEXT NOT NULL, - name TEXT, -- the display name of the room - avatar_url TEXT, - short_description TEXT, - long_description TEXT -); - -CREATE UNIQUE INDEX groups_idx ON groups(group_id); - - --- list of users the group server thinks are joined -CREATE TABLE group_users ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - is_admin BOOLEAN NOT NULL, - is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone -); - - -CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id); -CREATE INDEX groups_users_u_idx ON group_users(user_id); - --- list of users the group server thinks are invited -CREATE TABLE group_invites ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL -); - -CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id); -CREATE INDEX groups_invites_u_idx ON group_invites(user_id); - - -CREATE TABLE group_rooms ( - group_id TEXT NOT NULL, - room_id TEXT NOT NULL, - is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone -); - -CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id); -CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id); - - --- Rooms to include in the summary -CREATE TABLE group_summary_rooms ( - group_id TEXT NOT NULL, - room_id TEXT NOT NULL, - category_id TEXT NOT NULL, - room_order BIGINT NOT NULL, - is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone - UNIQUE (group_id, category_id, room_id, room_order), - CHECK (room_order > 0) -); - -CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); - - --- Categories to include in the summary -CREATE TABLE group_summary_room_categories ( - group_id TEXT NOT NULL, - category_id TEXT NOT NULL, - cat_order BIGINT NOT NULL, - UNIQUE (group_id, category_id, cat_order), - CHECK (cat_order > 0) -); - --- The categories in the group -CREATE TABLE group_room_categories ( - group_id TEXT NOT NULL, - category_id TEXT NOT NULL, - profile TEXT NOT NULL, - is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone - UNIQUE (group_id, category_id) -); - --- The users to include in the group summary -CREATE TABLE group_summary_users ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - role_id TEXT NOT NULL, - user_order BIGINT NOT NULL, - is_public BOOLEAN NOT NULL -- whether the user should be show to everyone -); - -CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); - --- The roles to include in the group summary -CREATE TABLE group_summary_roles ( - group_id TEXT NOT NULL, - role_id TEXT NOT NULL, - role_order BIGINT NOT NULL, - UNIQUE (group_id, role_id, role_order), - CHECK (role_order > 0) -); - - --- The roles in a groups -CREATE TABLE group_roles ( - group_id TEXT NOT NULL, - role_id TEXT NOT NULL, - profile TEXT NOT NULL, - is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone - UNIQUE (group_id, role_id) -); - - --- List of attestations we've given out and need to renew -CREATE TABLE group_attestations_renewals ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - valid_until_ms BIGINT NOT NULL -); - -CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); -CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); -CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); - - --- List of attestations we've received from remotes and are interested in. 
-CREATE TABLE group_attestations_remote ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - valid_until_ms BIGINT NOT NULL, - attestation_json TEXT NOT NULL -); - -CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); -CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); -CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); - - --- The group membership for the HS's users -CREATE TABLE local_group_membership ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - is_admin BOOLEAN NOT NULL, - membership TEXT NOT NULL, - is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership - content TEXT NOT NULL -); - -CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); -CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); - - -CREATE TABLE local_group_updates ( - stream_id BIGINT NOT NULL, - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - type TEXT NOT NULL, - content TEXT NOT NULL -); diff --git a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql deleted file mode 100644 index e5ddc84df0..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- A subset of remote users whose profiles we have cached. --- Whether a user is in this table or not is defined by the storage function --- `is_subscribed_remote_profile_for_user` -CREATE TABLE remote_profile_cache ( - user_id TEXT NOT NULL, - displayname TEXT, - avatar_url TEXT, - last_check BIGINT NOT NULL -); - -CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); -CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql deleted file mode 100644 index 68c48a89a9..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* we no longer use (or create) the refresh_tokens table */ -DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql deleted file mode 100644 index bb307889c1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- drop the unique constraint on deleted_pushers so that we can just insert --- into it rather than upserting. - -CREATE TABLE deleted_pushers2 ( - stream_id BIGINT NOT NULL, - app_id TEXT NOT NULL, - pushkey TEXT NOT NULL, - user_id TEXT NOT NULL -); - -INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) - SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; - -DROP TABLE deleted_pushers; -ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; - --- create the index after doing the inserts because that's more efficient. --- it also means we can give it the same name as the old one without renaming. -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); - diff --git a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql deleted file mode 100644 index 097679bc9a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE groups_new ( - group_id TEXT NOT NULL, - name TEXT, -- the display name of the room - avatar_url TEXT, - short_description TEXT, - long_description TEXT, - is_public BOOL NOT NULL -- whether non-members can access group APIs -); - --- NB: awful hack to get the default to be true on postgres and 1 on sqlite -INSERT INTO groups_new - SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; - -DROP TABLE groups; -ALTER TABLE groups_new RENAME TO groups; - -CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql deleted file mode 100644 index bbfc7f5d1a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- register a background update which will recreate the --- local_media_repository_url_idx index. --- --- We do this as a bg update not because it is a particularly onerous --- operation, but because we'd like it to be a partial index if possible, and --- the background_index_update code will understand whether we are on --- postgres or sqlite and behave accordingly. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql deleted file mode 100644 index cb0d5a2576..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- change the user_directory table to also cover global local user profiles --- rather than just profiles within specific rooms. - -CREATE TABLE user_directory2 ( - user_id TEXT NOT NULL, - room_id TEXT, - display_name TEXT, - avatar_url TEXT -); - -INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) - SELECT user_id, room_id, display_name, avatar_url from user_directory; - -DROP TABLE user_directory; -ALTER TABLE user_directory2 RENAME TO user_directory; - --- create indexes after doing the inserts because that's more efficient. 
--- it also means we can give it the same name as the old one without renaming. -CREATE INDEX user_directory_room_idx ON user_directory(room_id); -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql deleted file mode 100644 index d9505f8da1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- this is just embarassing :| -ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; - --- this is only 300K rows on matrix.org and takes ~3s to generate the index, --- so is hopefully not going to block anyone else for that long... -CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); -CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); -DROP INDEX users_in_pubic_room_room_idx; -DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql deleted file mode 100644 index f505fb22b5..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql deleted file mode 100644 index 31d7a817eb..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql deleted file mode 100644 index edccf4a96f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Temporary staging area for push actions that have been calculated for an --- event, but the event hasn't yet been persisted. --- When the event is persisted the rows are moved over to the --- event_push_actions table. -CREATE TABLE event_push_actions_staging ( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - actions TEXT NOT NULL, - notif SMALLINT NOT NULL, - highlight SMALLINT NOT NULL -); - -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql deleted file mode 100644 index 5237491506..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* record the version of the privacy policy the user has consented to - */ -ALTER TABLE users ADD COLUMN consent_version TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql deleted file mode 100644 index 9248b0b24a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql deleted file mode 100644 index e9013a6969..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Store any accounts that have been requested to be deactivated. - * We part the account from all the rooms its in when its - * deactivated. This can take some time and synapse may be restarted - * before it completes, so store the user IDs here until the process - * is complete. - */ -CREATE TABLE users_pending_deactivation ( - user_id TEXT NOT NULL -); diff --git a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py deleted file mode 100644 index 49f5f2c003..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -FIX_INDEXES = """ --- rebuild indexes as uniques -DROP INDEX groups_invites_g_idx; -CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); -DROP INDEX groups_users_g_idx; -CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); - --- rename other indexes to actually match their table names.. 
-DROP INDEX groups_users_u_idx; -CREATE INDEX group_users_u_idx ON group_users(user_id); -DROP INDEX groups_invites_u_idx; -CREATE INDEX group_invites_u_idx ON group_invites(user_id); -DROP INDEX groups_rooms_g_idx; -CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); -DROP INDEX groups_rooms_r_idx; -CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" - - # remove duplicates from group_users & group_invites tables - cur.execute( - """ - DELETE FROM group_users WHERE %s NOT IN ( - SELECT min(%s) FROM group_users GROUP BY group_id, user_id - ); - """ - % (rowid, rowid) - ) - cur.execute( - """ - DELETE FROM group_invites WHERE %s NOT IN ( - SELECT min(%s) FROM group_invites GROUP BY group_id, user_id - ); - """ - % (rowid, rowid) - ) - - for statement in get_statements(FIX_INDEXES.splitlines()): - cur.execute(statement) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql deleted file mode 100644 index ce26eaf0c9..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This isn't a real ENUM because sqlite doesn't support it - * and we use a default of NULL for inserted rows and interpret - * NULL at the python store level as necessary so that existing - * rows are given the correct default policy. - */ -ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql deleted file mode 100644 index 14dcf18d73..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* record whether we have sent a server notice about consenting to the - * privacy policy. Specifically records the version of the policy we sent - * a message about. 
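As an aside on the 48/group_unique_indexes delta above: its "delete all but one physical row per logical key, then add the unique index" step is a general pattern. Below is a minimal, self-contained sketch of that pattern using SQLite's implicit rowid and an illustrative table (not Synapse's real schema; the delta itself switches to ctid when running on Postgres):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    # Illustrative table containing a duplicated logical row.
    cur.execute("CREATE TABLE memberships (group_id TEXT NOT NULL, user_id TEXT NOT NULL)")
    cur.executemany(
        "INSERT INTO memberships VALUES (?, ?)",
        [("g1", "u1"), ("g1", "u1"), ("g1", "u2")],
    )
    # Keep only the first physical row for each (group_id, user_id) pair.
    cur.execute(
        "DELETE FROM memberships WHERE rowid NOT IN ("
        " SELECT MIN(rowid) FROM memberships GROUP BY group_id, user_id)"
    )
    # The unique index can now be created without tripping over duplicates.
    cur.execute("CREATE UNIQUE INDEX memberships_g_idx ON memberships(group_id, user_id)")
    print(cur.execute("SELECT COUNT(*) FROM memberships").fetchone()[0])  # prints 2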
- */ -ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql deleted file mode 100644 index 3dd478196f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, - device_id TEXT, - timestamp BIGINT NOT NULL ); -CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); -CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql deleted file mode 100644 index 3a4ed59b5b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('user_ips_last_seen_only_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql deleted file mode 100644 index c93ae47532..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - - -INSERT into background_updates (update_name, progress_json) - VALUES ('users_creation_ts', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql deleted file mode 100644 index 5d8641a9ab..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- a table of users who have requested that their details be erased -CREATE TABLE erased_users ( - user_id TEXT NOT NULL -); - -CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py deleted file mode 100644 index b1684a8441..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -We want to stop populating 'event.content', so we need to make it nullable. 
- -If this has to be rolled back, then the following should populate the missing data: - -Postgres: - - UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej - WHERE ej.event_id = events.event_id AND - stream_ordering < ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering LIMIT 1 - ); - - UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej - WHERE ej.event_id = events.event_id AND - stream_ordering > ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering DESC LIMIT 1 - ); - -SQLite: - - UPDATE events SET content=( - SELECT json_extract(json,'$.content') FROM event_json ej - WHERE ej.event_id = events.event_id - ) - WHERE - stream_ordering < ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering LIMIT 1 - ) - OR stream_ordering > ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering DESC LIMIT 1 - ); - -""" - -import logging - -from synapse.storage.engines import PostgresEngine - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - pass - - -def run_upgrade(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - cur.execute( - """ - ALTER TABLE events ALTER COLUMN content DROP NOT NULL; - """ - ) - return - - # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html - - cur.execute( - "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" - ) - (oldsql,) = cur.fetchone() - - sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") - if sql == oldsql: - raise Exception("Couldn't find null constraint to drop in %s" % oldsql) - - logger.info("Replacing definition of 'events' with: %s", sql) - - cur.execute("PRAGMA schema_version") - (oldver,) = cur.fetchone() - cur.execute("PRAGMA writable_schema=ON") - cur.execute( - "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", - (sql,), - ) - cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) - cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql deleted file mode 100644 index c0e66a697d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - --- users' optionally backed up encrypted e2e sessions -CREATE TABLE e2e_room_keys ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - session_id TEXT NOT NULL, - version TEXT NOT NULL, - first_message_index INT, - forwarded_count INT, - is_verified BOOLEAN, - session_data TEXT NOT NULL -); - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); - --- the metadata for each generation of encrypted e2e session backups -CREATE TABLE e2e_room_keys_versions ( - user_id TEXT NOT NULL, - version TEXT NOT NULL, - algorithm TEXT NOT NULL, - auth_data TEXT NOT NULL, - deleted SMALLINT DEFAULT 0 NOT NULL -); - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql deleted file mode 100644 index c9d537d5a3..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- a table of monthly active users, for use where blocking based on mau limits -CREATE TABLE monthly_active_users ( - user_id TEXT NOT NULL, - -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting - -- on updates, Granularity of updates governed by - -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY - -- Measured in ms since epoch. - timestamp BIGINT NOT NULL -); - -CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql deleted file mode 100644 index 91e03d13e1..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- This is needed to efficiently check for unreferenced state groups during --- purge. 
Added events_to_state_group(state_group) index -INSERT into background_updates (update_name, progress_json) - VALUES ('event_to_state_groups_sg_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql deleted file mode 100644 index bfa49e6f92..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- register a background update which will create a unique index on --- device_lists_remote_cache -INSERT into background_updates (update_name, progress_json) - VALUES ('device_lists_remote_cache_unique_idx', '{}'); - --- and one on device_lists_remote_extremeties -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ( - 'device_lists_remote_extremeties_unique_idx', '{}', - - -- doesn't really depend on this, but we need to make sure both happen - -- before we drop the old indexes. - 'device_lists_remote_cache_unique_idx' - ); - --- once they complete, we can drop the old indexes. -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ( - 'drop_device_list_streams_non_unique_indexes', '{}', - 'device_lists_remote_extremeties_unique_idx' - ); diff --git a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql deleted file mode 100644 index db687cccae..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* Change version column to an integer so we can do MAX() sensibly - */ -CREATE TABLE e2e_room_keys_versions_new ( - user_id TEXT NOT NULL, - version BIGINT NOT NULL, - algorithm TEXT NOT NULL, - auth_data TEXT NOT NULL, - deleted SMALLINT DEFAULT 0 NOT NULL -); - -INSERT INTO e2e_room_keys_versions_new - SELECT user_id, CAST(version as BIGINT), algorithm, auth_data, deleted FROM e2e_room_keys_versions; - -DROP TABLE e2e_room_keys_versions; -ALTER TABLE e2e_room_keys_versions_new RENAME TO e2e_room_keys_versions; - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); - -/* Change e2e_rooms_keys to match - */ -CREATE TABLE e2e_room_keys_new ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - session_id TEXT NOT NULL, - version BIGINT NOT NULL, - first_message_index INT, - forwarded_count INT, - is_verified BOOLEAN, - session_data TEXT NOT NULL -); - -INSERT INTO e2e_room_keys_new - SELECT user_id, room_id, session_id, CAST(version as BIGINT), first_message_index, forwarded_count, is_verified, session_data FROM e2e_room_keys; - -DROP TABLE e2e_room_keys; -ALTER TABLE e2e_room_keys_new RENAME TO e2e_room_keys; - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); diff --git a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql deleted file mode 100644 index 88ec2f83e5..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* The type of the user: NULL for a regular user, or one of the constants in - * synapse.api.constants.UserTypes - */ -ALTER TABLE users ADD COLUMN user_type TEXT DEFAULT NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql deleted file mode 100644 index e372f5a44a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
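The 52/e2e_room_keys delta above works around SQLite's lack of ALTER COLUMN by creating a replacement table with the new column type, copying the rows across with a CAST, and swapping the tables. The following is a rough sketch of that create-copy-swap pattern on a cut-down illustrative table (not the delta's full column list):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE backups (user_id TEXT NOT NULL, version TEXT NOT NULL)")
    cur.execute("INSERT INTO backups VALUES ('@user:example.com', '3')")

    # 1. Create the replacement table with the desired column type.
    cur.execute("CREATE TABLE backups_new (user_id TEXT NOT NULL, version BIGINT NOT NULL)")
    # 2. Copy the data across, casting the column on the way through.
    cur.execute("INSERT INTO backups_new SELECT user_id, CAST(version AS BIGINT) FROM backups")
    # 3. Swap the tables and recreate any indexes on the new table.
    cur.execute("DROP TABLE backups")
    cur.execute("ALTER TABLE backups_new RENAME TO backups")
    cur.execute("CREATE UNIQUE INDEX backups_idx ON backups(user_id, version)")

    print(cur.execute("SELECT MAX(version) FROM backups").fetchone())  # prints (3,)

With the column stored as an integer, MAX(version) compares numerically rather than lexically, which is the point of the delta.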
- */ - -DROP TABLE IF EXISTS sent_transactions; diff --git a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql deleted file mode 100644 index 1d977c2834..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE event_json ADD COLUMN format_version INTEGER; diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql deleted file mode 100644 index ffcc896b58..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Set up staging tables -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_user_directory_createtables', '{}'); - --- Run through each room and update the user directory according to who is in it -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables'); - --- Insert all users, if search_all_users is on -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms'); - --- Clean up staging tables -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users'); diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql deleted file mode 100644 index b812c5794f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -- analyze user_ips, to help ensure the correct indices are used -INSERT INTO background_updates (update_name, progress_json) VALUES - ('user_ips_analyze', '{}'); - --- delete duplicates -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('user_ips_remove_dupes', '{}', 'user_ips_analyze'); - --- add a new unique index to user_ips table -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes'); - --- drop the old original index -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql deleted file mode 100644 index 5831b1a6f8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Old disused version of the tables below. -DROP TABLE IF EXISTS users_who_share_rooms; - --- Tables keeping track of what users share rooms. This is a map of local users --- to local or remote users, per room. Remote users cannot be in the user_id --- column, only the other_user_id column. There are two tables, one for public --- rooms and those for private rooms. 
-CREATE TABLE IF NOT EXISTS users_who_share_public_rooms ( - user_id TEXT NOT NULL, - other_user_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS users_who_share_private_rooms ( - user_id TEXT NOT NULL, - other_user_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id); -CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id); -CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id); - -CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); -CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); -CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); - --- Make sure that we populate the tables initially by resetting the stream ID -UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql deleted file mode 100644 index 80c2c573b6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Tracks which identity server a user bound their threepid via. -CREATE TABLE user_threepid_id_server ( - user_id TEXT NOT NULL, - medium TEXT NOT NULL, - address TEXT NOT NULL, - id_server TEXT NOT NULL -); - -CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( - user_id, medium, address, id_server -); - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('user_threepids_grandfather', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql deleted file mode 100644 index f7827ca6d2..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We don't need the old version of this table. 
-DROP TABLE IF EXISTS users_in_public_rooms; - --- Old version of users_in_public_rooms -DROP TABLE IF EXISTS users_who_share_public_rooms; - --- Track what users are in public rooms. -CREATE TABLE IF NOT EXISTS users_in_public_rooms ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql deleted file mode 100644 index 0adb2ad55e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We previously changed the schema for this table without renaming the file, which means --- that some databases might still be using the old schema. This ensures Synapse uses the --- right schema for the table. -DROP TABLE IF EXISTS account_validity; - --- Track what users are in public rooms. -CREATE TABLE IF NOT EXISTS account_validity ( - user_id TEXT PRIMARY KEY, - expiration_ts_ms BIGINT NOT NULL, - email_sent BOOLEAN NOT NULL, - renewal_token TEXT -); - -CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) -CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) diff --git a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql deleted file mode 100644 index c01aa9d2d9..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* When we can use this key until, before we have to refresh it. 
*/ -ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; - -UPDATE server_signature_keys SET ts_valid_until_ms = ( - SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE - skj.server_name = server_signature_keys.server_name AND - skj.key_id = server_signature_keys.key_id -); diff --git a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql deleted file mode 100644 index b062ec840c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Start a background job to cleanup extremities that were incorrectly added --- by bug #5269. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('delete_soft_failed_extremities', '{}'); - -DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. -CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; -CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql deleted file mode 100644 index dbbe682697..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - --- we need to do this first due to foreign constraints -DROP TABLE IF EXISTS application_services_regex; - -DROP TABLE IF EXISTS application_services; -DROP TABLE IF EXISTS transaction_id_to_pdu; -DROP TABLE IF EXISTS stats_reporting; -DROP TABLE IF EXISTS current_state_resets; -DROP TABLE IF EXISTS event_content_hashes; -DROP TABLE IF EXISTS event_destinations; -DROP TABLE IF EXISTS event_edge_hashes; -DROP TABLE IF EXISTS event_signatures; -DROP TABLE IF EXISTS feedback; -DROP TABLE IF EXISTS room_hosts; -DROP TABLE IF EXISTS server_tls_certificates; -DROP TABLE IF EXISTS state_forward_extremities; diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql deleted file mode 100644 index e6ee70c623..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP TABLE IF EXISTS presence_list; diff --git a/synapse/storage/data_stores/main/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql deleted file mode 100644 index 134862b870..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/relations.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Tracks related events, like reactions, replies, edits, etc. Note that things --- in this table are not necessarily "valid", e.g. it may contain edits from --- people who don't have power to edit other peoples events. -CREATE TABLE IF NOT EXISTS event_relations ( - event_id TEXT NOT NULL, - relates_to_id TEXT NOT NULL, - relation_type TEXT NOT NULL, - aggregation_key TEXT -); - -CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); -CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql deleted file mode 100644 index 652e58308e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/stats.sql +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE stats_stream_pos ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT, - CHECK (Lock='X') -); - -INSERT INTO stats_stream_pos (stream_id) VALUES (null); - -CREATE TABLE user_stats ( - user_id TEXT NOT NULL, - ts BIGINT NOT NULL, - bucket_size INT NOT NULL, - public_rooms INT NOT NULL, - private_rooms INT NOT NULL -); - -CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); - -CREATE TABLE room_stats ( - room_id TEXT NOT NULL, - ts BIGINT NOT NULL, - bucket_size INT NOT NULL, - current_state_events INT NOT NULL, - joined_members INT NOT NULL, - invited_members INT NOT NULL, - left_members INT NOT NULL, - banned_members INT NOT NULL, - state_events INT NOT NULL -); - -CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); - --- cache of current room state; useful for the publicRooms list -CREATE TABLE room_state ( - room_id TEXT NOT NULL, - join_rules TEXT, - history_visibility TEXT, - encryption TEXT, - name TEXT, - topic TEXT, - avatar TEXT, - canonical_alias TEXT - -- get aliases straight from the right table -); - -CREATE UNIQUE INDEX room_state_room ON room_state(room_id); - -CREATE TABLE room_stats_earliest_token ( - room_id TEXT NOT NULL, - token BIGINT NOT NULL -); - -CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); - --- Set up staging tables -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_createtables', '{}'); - --- Run through each room and update stats -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); - --- Clean up staging tables -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql deleted file mode 100644 index 3b2d48447f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- This delta file gets run after `54/stats.sql` delta. - --- We want to add some indices to the temporary stats table, so we re-insert --- 'populate_stats_createtables' if we are still processing the rooms update. 
-INSERT INTO background_updates (update_name, progress_json) - SELECT 'populate_stats_createtables', '{}' - WHERE - 'populate_stats_process_rooms' IN ( - SELECT update_name FROM background_updates - ) - AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists - SELECT update_name FROM background_updates - ); diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql deleted file mode 100644 index 4590604bfd..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- when this access token can be used until, in ms since the epoch. NULL means the token --- never expires. -ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql deleted file mode 100644 index a8eced2e0a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS threepid_validation_session ( - session_id TEXT PRIMARY KEY, - medium TEXT NOT NULL, - address TEXT NOT NULL, - client_secret TEXT NOT NULL, - last_send_attempt BIGINT NOT NULL, - validated_at BIGINT -); - -CREATE TABLE IF NOT EXISTS threepid_validation_token ( - token TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - next_link TEXT, - expires BIGINT NOT NULL -); - -CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token(session_id); diff --git a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql deleted file mode 100644 index dabdde489b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
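The conditional INSERT ... SELECT in the 54/stats2 delta above is an idempotent scheduling trick: the row is only added if the rooms pass is still pending and the createtables step is not already queued, so re-running the delta is harmless. A small sketch of the same pattern against a toy background_updates table (illustrative only):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE background_updates (update_name TEXT PRIMARY KEY, progress_json TEXT NOT NULL)"
    )
    # Pretend the rooms pass from the earlier stats delta is still pending.
    cur.execute("INSERT INTO background_updates VALUES ('populate_stats_process_rooms', '{}')")

    requeue = (
        "INSERT INTO background_updates (update_name, progress_json) "
        "SELECT 'populate_stats_createtables', '{}' "
        "WHERE 'populate_stats_process_rooms' IN (SELECT update_name FROM background_updates) "
        "AND 'populate_stats_createtables' NOT IN (SELECT update_name FROM background_updates)"
    )
    cur.execute(requeue)
    cur.execute(requeue)  # the second run inserts nothing, so repeating it is safe
    print([row[0] for row in cur.execute("SELECT update_name FROM background_updates ORDER BY 1")])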
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE users ADD deactivated SMALLINT DEFAULT 0 NOT NULL; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('users_set_deactivated_flag', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql deleted file mode 100644 index 41807eb1e7..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Opentracing context data for inclusion in the device_list_update EDUs, as a - * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination). - */ -ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql deleted file mode 100644 index 473018676f..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We add membership to current state so that we don't need to join against --- room_memberships, which can be surprisingly costly (we do such queries --- very frequently). --- This will be null for non-membership events and the content.membership key --- for membership events. (Will also be null for membership events until the --- background update job has finished). 
-ALTER TABLE current_state_events ADD membership TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql deleted file mode 100644 index 3133d42d4a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We add membership to current state so that we don't need to join against --- room_memberships, which can be surprisingly costly (we do such queries --- very frequently). --- This will be null for non-membership events and the content.membership key --- for membership events. (Will also be null for membership events until the --- background update job has finished). - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('current_state_events_membership', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql deleted file mode 100644 index 1d2ddb1b1a..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* delete room keys that belong to deleted room key version, or to room key - * versions that don't exist (anymore) - */ -DELETE FROM e2e_room_keys -WHERE version NOT IN ( - SELECT version - FROM e2e_room_keys_versions - WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id - AND e2e_room_keys_versions.deleted = 0 -); diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql deleted file mode 100644 index f00889290b..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
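The 56/delete_keys_from_deleted_backups delta above removes room keys orphaned by a deleted (or missing) backup version, matching versions per user with a correlated subquery. Below is a minimal sketch of that correlated NOT IN delete on toy tables (table and column names are illustrative, not Synapse's):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.executescript(
        """
        CREATE TABLE room_keys (user_id TEXT NOT NULL, version TEXT NOT NULL);
        CREATE TABLE key_versions (user_id TEXT NOT NULL, version TEXT NOT NULL,
                                   deleted SMALLINT NOT NULL DEFAULT 0);
        INSERT INTO key_versions VALUES ('@a:hs', '1', 0), ('@a:hs', '2', 1);
        INSERT INTO room_keys VALUES ('@a:hs', '1'), ('@a:hs', '2'), ('@a:hs', '3');
        """
    )
    # Keep only keys whose backup version still exists and is not deleted;
    # the subquery is correlated on user_id, mirroring the delta above.
    cur.execute(
        """
        DELETE FROM room_keys
        WHERE version NOT IN (
            SELECT version FROM key_versions
            WHERE key_versions.user_id = room_keys.user_id
              AND key_versions.deleted = 0
        )
        """
    )
    print(cur.execute("SELECT version FROM room_keys").fetchall())  # prints [('1',)]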
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Record the timestamp when a given server started failing - */ -ALTER TABLE destinations ADD failure_ts BIGINT; - -/* as a rough approximation, we assume that the server started failing at - * retry_interval before the last retry - */ -UPDATE destinations SET failure_ts = retry_last_ts - retry_interval - WHERE retry_last_ts > 0; diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres deleted file mode 100644 index b9bbb18a91..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We want to store large retry intervals so we upgrade the column from INT --- to BIGINT. We don't need to do this on SQLite. -ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql deleted file mode 100644 index c2f557fde9..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- This line already existed in deltas/35/device_stream_id but was not included in the --- 54 full schema SQL. 
Add some SQL here to insert the missing row if it does not exist -INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( - SELECT * from device_max_stream_id -); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql deleted file mode 100644 index dfa902d0ba..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation CIC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Track last seen information for a device in the devices table, rather --- than relying on it being in the user_ips table (which we want to be able --- to purge old entries from) -ALTER TABLE devices ADD COLUMN last_seen BIGINT; -ALTER TABLE devices ADD COLUMN ip TEXT; -ALTER TABLE devices ADD COLUMN user_agent TEXT; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('devices_last_seen', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql deleted file mode 100644 index 9f09922c67..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- these tables are never used. -DROP TABLE IF EXISTS room_names; -DROP TABLE IF EXISTS topics; -DROP TABLE IF EXISTS history_visibility; -DROP TABLE IF EXISTS guest_access; diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql deleted file mode 100644 index 81a36a8b1d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS event_expiry ( - event_id TEXT PRIMARY KEY, - expiry_ts BIGINT NOT NULL -); - -CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql deleted file mode 100644 index 5e29c1da19..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- room_id and topological_ordering are denormalised from the events table in order to --- make the index work. -CREATE TABLE IF NOT EXISTS event_labels ( - event_id TEXT, - label TEXT, - room_id TEXT NOT NULL, - topological_ordering BIGINT NOT NULL, - PRIMARY KEY(event_id, label) -); - - --- This index enables an event pagination looking for a particular label to index the --- event_labels table first, which is much quicker than scanning the events table and then --- filtering by label, if the label is rarely used relative to the size of the room. -CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql deleted file mode 100644 index 5f5e0499ae..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('event_store_labels', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql deleted file mode 100644 index 014cb3b538..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation CIC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- version is supposed to be part of the room keys index -CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id); -DROP INDEX IF EXISTS e2e_room_keys_idx; diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql deleted file mode 100644 index 67f8b20297..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- device list needs to know which ones are "real" devices, and which ones are --- just used to avoid collisions -ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE; diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite deleted file mode 100644 index e8b1fd35d8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Change the hidden column from a default value of FALSE to a default value of - * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the - * string 'FALSE', which is truthy. - * - * Since sqlite doesn't allow us to just change the default value, we have to - * recreate the table, copy the data, fix the rows that have incorrect data, and - * replace the old table with the new table. 
- */ - -CREATE TABLE IF NOT EXISTS devices2 ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - display_name TEXT, - last_seen BIGINT, - ip TEXT, - user_agent TEXT, - hidden BOOLEAN DEFAULT 0, - CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) -); - -INSERT INTO devices2 SELECT * FROM devices; - -UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE'; - -DROP TABLE devices; - -ALTER TABLE devices2 RENAME TO devices; diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql deleted file mode 100644 index 4f24c1405d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2019 Werner Sembach - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. -DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); -DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users 
WHERE group_id = groups.group_id)); -DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql deleted file mode 100644 index 7be31ffebb..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql deleted file mode 100644 index ea95db0ed7..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql deleted file mode 100644 index 49ce35d794..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -ALTER TABLE redactions ADD COLUMN received_ts BIGINT; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('redactions_received_ts', '{}'); - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('redactions_have_censored_ts_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres deleted file mode 100644 index 67471f3ef5..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- There was a bug where we may have updated censored redactions as bytes, --- which can (somehow) cause json to be inserted hex encoded. These updates go --- and undoes any such hex encoded JSON. - -INSERT into background_updates (update_name, progress_json) - VALUES ('event_fix_redactions_bytes_create_index', '{}'); - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql deleted file mode 100644 index b7550f6f4e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP INDEX IF EXISTS redactions_have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql deleted file mode 100644 index aeb17813d3..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Now that #6232 is a thing, we can remove old rooms from the directory. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('remove_tombstoned_rooms_from_directory', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql deleted file mode 100644 index 7d70dd071e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- store the current etag of backup version -ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql deleted file mode 100644 index 92ab1f5e65..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Adds an index on room_memberships for fetching all forgotten rooms for a user -INSERT INTO background_updates (update_name, progress_json) VALUES - ('room_membership_forgotten_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql deleted file mode 100644 index ee6cdf7a14..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Tracks the retention policy of a room. --- A NULL max_lifetime or min_lifetime means that the matching property is not defined in --- the room's retention policy state event. --- If a room doesn't have a retention policy state event in its state, both max_lifetime --- and min_lifetime are NULL. -CREATE TABLE IF NOT EXISTS room_retention( - room_id TEXT, - event_id TEXT, - min_lifetime BIGINT, - max_lifetime BIGINT, - - PRIMARY KEY(room_id, event_id) -); - -CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('insert_room_retention', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql deleted file mode 100644 index 5c5fffcafb..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - --- cross-signing keys -CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( - user_id TEXT NOT NULL, - -- the type of cross-signing key (master, user_signing, or self_signing) - keytype TEXT NOT NULL, - -- the full key information, as a json-encoded dict - keydata TEXT NOT NULL, - -- for keeping the keys in order, so that we can fetch the latest one - stream_id BIGINT NOT NULL -); - -CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id); - --- cross-signing signatures -CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( - -- user who did the signing - user_id TEXT NOT NULL, - -- key used to sign - key_id TEXT NOT NULL, - -- user who was signed - target_user_id TEXT NOT NULL, - -- device/key that was signed - target_device_id TEXT NOT NULL, - -- the actual signature - signature TEXT NOT NULL -); - --- replaced by the index created in signing_keys_nonunique_signatures.sql --- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); - --- stream of user signature updates -CREATE TABLE IF NOT EXISTS user_signature_stream ( - -- uses the same stream ID as device list stream - stream_id BIGINT NOT NULL, - -- user who did the signing - from_user_id TEXT NOT NULL, - -- list of users who were signed, as a JSON array - user_ids TEXT NOT NULL -); - -CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql deleted file mode 100644 index 0aa90ebf0c..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* The cross-signing signatures index should not be a unique index, because a - * user may upload multiple signatures for the same target user. The previous - * index was unique, so delete it if it's there and create a new non-unique - * index. */ - -DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT -EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql deleted file mode 100644 index bbdde121e8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql +++ /dev/null @@ -1,156 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - ------ First clean up from previous versions of room stats. - --- First remove old stats stuff -DROP TABLE IF EXISTS room_stats; -DROP TABLE IF EXISTS room_state; -DROP TABLE IF EXISTS room_stats_state; -DROP TABLE IF EXISTS user_stats; -DROP TABLE IF EXISTS room_stats_earliest_tokens; -DROP TABLE IF EXISTS _temp_populate_stats_position; -DROP TABLE IF EXISTS _temp_populate_stats_rooms; -DROP TABLE IF EXISTS stats_stream_pos; - --- Unschedule old background updates if they're still scheduled -DELETE FROM background_updates WHERE update_name IN ( - 'populate_stats_createtables', - 'populate_stats_process_rooms', - 'populate_stats_process_users', - 'populate_stats_cleanup' -); - --- this relies on current_state_events.membership having been populated, so add --- a dependency on current_state_events_membership. -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); - --- this also relies on current_state_events.membership having been populated, but --- we get that as a side-effect of depending on populate_stats_process_rooms. -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); - ------ Create tables for our version of room stats. - --- single-row table to track position of incremental updates -DROP TABLE IF EXISTS stats_incremental_position; -CREATE TABLE stats_incremental_position ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT NOT NULL, - CHECK (Lock='X') -); - --- insert a null row and make sure it is the only one. -INSERT INTO stats_incremental_position ( - stream_id -) SELECT COALESCE(MAX(stream_ordering), 0) from events; - --- represents PRESENT room statistics for a room --- only holds absolute fields -DROP TABLE IF EXISTS room_stats_current; -CREATE TABLE room_stats_current ( - room_id TEXT NOT NULL PRIMARY KEY, - - -- These are absolute counts - current_state_events INT NOT NULL, - joined_members INT NOT NULL, - invited_members INT NOT NULL, - left_members INT NOT NULL, - banned_members INT NOT NULL, - - local_users_in_room INT NOT NULL, - - -- The maximum delta stream position that this row takes into account. - completed_delta_stream_id BIGINT NOT NULL -); - - --- represents HISTORICAL room statistics for a room -DROP TABLE IF EXISTS room_stats_historical; -CREATE TABLE room_stats_historical ( - room_id TEXT NOT NULL, - -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms). - -- Note that end_ts is quantised. 
- end_ts BIGINT NOT NULL, - bucket_size BIGINT NOT NULL, - - -- These stats are absolute counts - current_state_events BIGINT NOT NULL, - joined_members BIGINT NOT NULL, - invited_members BIGINT NOT NULL, - left_members BIGINT NOT NULL, - banned_members BIGINT NOT NULL, - local_users_in_room BIGINT NOT NULL, - - -- These stats are per time slice - total_events BIGINT NOT NULL, - total_event_bytes BIGINT NOT NULL, - - PRIMARY KEY (room_id, end_ts) -); - --- We use this index to speed up deletion of ancient room stats. -CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts); - --- represents PRESENT statistics for a user --- only holds absolute fields -DROP TABLE IF EXISTS user_stats_current; -CREATE TABLE user_stats_current ( - user_id TEXT NOT NULL PRIMARY KEY, - - joined_rooms BIGINT NOT NULL, - - -- The maximum delta stream position that this row takes into account. - completed_delta_stream_id BIGINT NOT NULL -); - --- represents HISTORICAL statistics for a user -DROP TABLE IF EXISTS user_stats_historical; -CREATE TABLE user_stats_historical ( - user_id TEXT NOT NULL, - end_ts BIGINT NOT NULL, - bucket_size BIGINT NOT NULL, - - joined_rooms BIGINT NOT NULL, - - invites_sent BIGINT NOT NULL, - rooms_created BIGINT NOT NULL, - total_events BIGINT NOT NULL, - total_event_bytes BIGINT NOT NULL, - - PRIMARY KEY (user_id, end_ts) -); - --- We use this index to speed up deletion of ancient user stats. -CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts); - - -CREATE TABLE room_stats_state ( - room_id TEXT NOT NULL, - name TEXT, - canonical_alias TEXT, - join_rules TEXT, - history_visibility TEXT, - encryption TEXT, - avatar TEXT, - guest_access TEXT, - is_federatable BOOLEAN, - topic TEXT -); - -CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py deleted file mode 100644 index 1de8b54961..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging - -from synapse.storage.engines import PostgresEngine - -logger = logging.getLogger(__name__) - - -""" -This migration updates the user_filters table as follows: - - - drops any (user_id, filter_id) duplicates - - makes the columns NON-NULLable - - turns the index into a UNIQUE index -""" - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - select_clause = """ - SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json - FROM user_filters - """ - else: - select_clause = """ - SELECT * FROM user_filters GROUP BY user_id, filter_id - """ - sql = """ - DROP TABLE IF EXISTS user_filters_migration; - DROP INDEX IF EXISTS user_filters_unique; - CREATE TABLE user_filters_migration ( - user_id TEXT NOT NULL, - filter_id BIGINT NOT NULL, - filter_json BYTEA NOT NULL - ); - INSERT INTO user_filters_migration (user_id, filter_id, filter_json) - %s; - CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration - (user_id, filter_id); - DROP TABLE user_filters; - ALTER TABLE user_filters_migration RENAME TO user_filters; - """ % ( - select_clause, - ) - - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) diff --git 
a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql deleted file mode 100644 index 91390c4527..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * a table which records mappings from external auth providers to mxids - */ -CREATE TABLE IF NOT EXISTS user_external_ids ( - auth_provider TEXT NOT NULL, - external_id TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (auth_provider, external_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql deleted file mode 100644 index 149f8be8b6..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation CIC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- this was apparently forgotten when the table was created back in delta 53. -CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql deleted file mode 100644 index aec06c8261..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Add background update to go and delete current state events for rooms the --- server is no longer in. --- --- this relies on the 'membership' column of current_state_events, so make sure --- that's populated first! 
-INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('delete_old_current_state_events', '{}', 'current_state_events_membership'); diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql deleted file mode 100644 index c3b6de2099..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Records whether the server thinks that the remote users cached device lists --- may be out of date (e.g. if we have received a to device message from a --- device we don't know about). -CREATE TABLE IF NOT EXISTS device_lists_remote_resync ( - user_id TEXT NOT NULL, - added_ts BIGINT NOT NULL -); - -CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id); -CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py deleted file mode 100644 index 63b5acdcf7..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# We create a new table called `local_current_membership` that stores the latest -# membership state of local users in rooms, which helps track leaves/bans/etc -# even if the server has left the room (and so has deleted the room from -# `current_state_events`). This will also include outstanding invites for local -# users for rooms the server isn't in. -# -# If the server isn't and hasn't been in the room then it will only include -# outstanding invites, and not e.g. pre-emptive bans of local users. -# -# If the server later rejoins a room `local_current_membership` can simply be -# replaced with the new current state of the room (which results in the -# equivalent behaviour as if the server had remained in the room). - - -def run_upgrade(cur, database_engine, config, *args, **kwargs): - # We need to do the insert in `run_upgrade` section as we don't have access - # to `config` in `run_create`.
- - # This upgrade may take a bit of time for large servers (e.g. one minute for - # matrix.org) but means we avoid a lot of book keeping required to do it as - # a background update. - - # We check if the `current_state_events.membership` is up to date by - # checking if the relevant background update has finished. If it has - # finished we can avoid doing a join against `room_memberships`, which - # speeds things up. - cur.execute( - """SELECT 1 FROM background_updates - WHERE update_name = 'current_state_events_membership' - """ - ) - current_state_membership_up_to_date = not bool(cur.fetchone()) - - # Cheekily drop and recreate indices, as that is faster. - cur.execute("DROP INDEX local_current_membership_idx") - cur.execute("DROP INDEX local_current_membership_room_idx") - - if current_state_membership_up_to_date: - sql = """ - INSERT INTO local_current_membership (room_id, user_id, event_id, membership) - SELECT c.room_id, state_key AS user_id, event_id, c.membership - FROM current_state_events AS c - WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ? - """ - else: - # We can't rely on the membership column, so we need to join against - # `room_memberships`. - sql = """ - INSERT INTO local_current_membership (room_id, user_id, event_id, membership) - SELECT c.room_id, state_key AS user_id, event_id, r.membership - FROM current_state_events AS c - INNER JOIN room_memberships AS r USING (event_id) - WHERE type = 'm.room.member' AND state_key LIKE ? - """ - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("%:" + config.server_name,)) - - cur.execute( - "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" - ) - cur.execute( - "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" - ) - - -def run_create(cur, database_engine, *args, **kwargs): - cur.execute( - """ - CREATE TABLE local_current_membership ( - room_id TEXT NOT NULL, - user_id TEXT NOT NULL, - event_id TEXT NOT NULL, - membership TEXT NOT NULL - )""" - ) - - cur.execute( - "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" - ) - cur.execute( - "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" - ) diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql deleted file mode 100644 index 133d80af35..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- we no longer keep sent outbound device pokes in the db; clear them out --- so that we don't have to worry about them. --- --- This is a sequence scan, but it doesn't take too long.
- -DELETE FROM device_lists_outbound_pokes WHERE sent; diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql deleted file mode 100644 index 352a66f5b0..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- We want to start storing the room version independently of --- `current_state_events` so that we can delete stale entries from it without --- losing the information. -ALTER TABLE rooms ADD COLUMN room_version TEXT; - - -INSERT into background_updates (update_name, progress_json) - VALUES ('add_rooms_room_version_column', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres deleted file mode 100644 index c601cff6de..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- when we first added the room_version column, it was populated via a background --- update. We now need it to be populated before synapse starts, so we populate --- any remaining rows with a NULL room version now. For servers which have completed --- the background update, this will be pretty quick. - --- the following query will set room_version to NULL if no create event is found for --- the room in current_state_events, and will set it to '1' if a create event with no --- room_version is found. - -UPDATE rooms SET room_version=( - SELECT COALESCE(json::json->'content'->>'room_version','1') - FROM current_state_events cse INNER JOIN event_json ej USING (event_id) - WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key='' -) WHERE rooms.room_version IS NULL; - --- we still allow the background update to complete: it has the useful side-effect of --- populating `rooms` with any missing rooms (based on the current_state_events table). - --- see also rooms_version_column_2.sql.sqlite which has a copy of the above query, using --- sqlite syntax for the json extraction. 
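A minimal sanity check, assuming direct SQL access to the homeserver database (illustrative only, not shipped in any delta file): rooms for which no create event could be found keep a NULL room_version, so any rows the backfill above could not resolve can be counted afterwards with:

-- count rooms whose version could not be determined from current_state_events
SELECT COUNT(*) AS rooms_without_version FROM rooms WHERE room_version IS NULL;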
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite deleted file mode 100644 index 335c6f2074..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- see rooms_version_column_2.sql.postgres for details of what's going on here. - -UPDATE rooms SET room_version=( - SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') - FROM current_state_events cse INNER JOIN event_json ej USING (event_id) - WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key='' -) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres deleted file mode 100644 index 92aaadde0d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- When we first added the room_version column to the rooms table, it was populated from --- the current_state_events table. However, there was an issue causing a background --- update to clean up the current_state_events table for rooms where the server is no --- longer participating, before that column could be populated. Therefore, some rooms had --- a NULL room_version. - --- The rooms_version_column_2.sql.* delta files were introduced to make the populating --- synchronous instead of running it in a background update, which fixed this issue. --- However, all of the instances of Synapse installed or updated in the meantime got --- their rooms table corrupted with NULL room_versions. - --- This query fishes out the room versions from the create event using the state_events --- table instead of the current_state_events one, as the former still have all of the --- create events. 
- -UPDATE rooms SET room_version=( - SELECT COALESCE(json::json->'content'->>'room_version','1') - FROM state_events se INNER JOIN event_json ej USING (event_id) - WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' - LIMIT 1 -) WHERE rooms.room_version IS NULL; - --- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using --- sqlite syntax for the json extraction. diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite deleted file mode 100644 index e19dab97cb..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- see rooms_version_column_3.sql.postgres for details of what's going on here. - -UPDATE rooms SET room_version=( - SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') - FROM state_events se INNER JOIN event_json ej USING (event_id) - WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' - LIMIT 1 -) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql deleted file mode 100644 index fdc39e9ba5..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - /* for some reason, we have accumulated duplicate entries in - * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less - * efficient. 
- */ - -INSERT INTO background_updates (ordering, update_name, progress_json) - VALUES (5800, 'remove_dup_outbound_pokes', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql deleted file mode 100644 index dcb593fc2d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS ui_auth_sessions( - session_id TEXT NOT NULL, -- The session ID passed to the client. - creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). - serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. - clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. - uri TEXT NOT NULL, -- The URI the UI authentication session is using. - method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. - -- The clientdict, uri, and method make up a tuple that must be immutable - -- throughout the lifetime of the UI Auth session. - description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. - UNIQUE (session_id) -); - -CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials( - session_id TEXT NOT NULL, -- The corresponding UI Auth session. - stage_type TEXT NOT NULL, -- The stage type. - result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. - UNIQUE (session_id, stage_type), - FOREIGN KEY (session_id) - REFERENCES ui_auth_sessions (session_id) -); diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres deleted file mode 100644 index aa46eb0e10..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We keep the old table here to enable us to roll back. It doesn't matter --- that we have dropped all the data here.
-TRUNCATE cache_invalidation_stream; - -CREATE TABLE cache_invalidation_stream_by_instance ( - stream_id BIGINT NOT NULL, - instance_name TEXT NOT NULL, - cache_func TEXT NOT NULL, - keys TEXT[], - invalidation_ts BIGINT -); - -CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id); - -CREATE SEQUENCE cache_invalidation_stream_seq; diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py deleted file mode 100644 index d353f2bcb3..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This migration rebuilds the device_lists_outbound_last_success table without duplicate -entries, and with a UNIQUE index. -""" - -import logging -from io import StringIO - -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.prepare_database import execute_statements_from_stream -from synapse.storage.types import Cursor - -logger = logging.getLogger(__name__) - - -def run_upgrade(*args, **kwargs): - pass - - -def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): - # some instances might already have this index, in which case we can skip this - if isinstance(database_engine, PostgresEngine): - cur.execute( - """ - SELECT 1 FROM pg_class WHERE relkind = 'i' - AND relname = 'device_lists_outbound_last_success_unique_idx' - """ - ) - - if cur.rowcount: - logger.info( - "Unique index exists on device_lists_outbound_last_success: " - "skipping rebuild" - ) - return - - logger.info("Rebuilding device_lists_outbound_last_success with unique index") - execute_statements_from_stream(cur, StringIO(_rebuild_commands)) - - -# there might be duplicates, so the easiest way to achieve this is to create a new -# table with the right data, and renaming it into place - -_rebuild_commands = """ -DROP TABLE IF EXISTS device_lists_outbound_last_success_new; - -CREATE TABLE device_lists_outbound_last_success_new ( - destination TEXT NOT NULL, - user_id TEXT NOT NULL, - stream_id BIGINT NOT NULL -); - --- this took about 30 seconds on matrix.org's 16 million rows. -INSERT INTO device_lists_outbound_last_success_new - SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success - GROUP BY destination, user_id; - --- and this another 30 seconds. 
-CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx - ON device_lists_outbound_last_success_new (destination, user_id); - -DROP TABLE device_lists_outbound_last_success; - -ALTER TABLE device_lists_outbound_last_success_new - RENAME TO device_lists_outbound_last_success; -""" diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres deleted file mode 100644 index 597f2ffd3d..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- The local_media_repository should have files which do not get quarantined, --- e.g. files from sticker packs. -ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite deleted file mode 100644 index 69db89ac0e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- The local_media_repository should have files which do not get quarantined, --- e.g. files from sticker packs. -ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0; diff --git a/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql deleted file mode 100644 index eb57203e46..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* -The version of synapse 1.16.0 on pypi incorrectly contained a migration which -added a table called 'local_rejections_stream'. This table is not used, and -we drop it here for anyone who was affected. -*/ - -DROP TABLE IF EXISTS local_rejections_stream; diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql deleted file mode 100644 index 1cc2633aad..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We need to store the stream positions by instance in a sharded config world. --- --- We default to master as we want the column to be NOT NULL and we correctly --- reset the instance name to match the config each time we start up. -ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master'; - -CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py deleted file mode 100644 index 2011f6bceb..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Adds a postgres SEQUENCE for generating guest user IDs. 
-""" - -from synapse.storage.data_stores.main.registration import ( - find_max_generated_user_id_localpart, -) -from synapse.storage.engines import PostgresEngine - - -def run_create(cur, database_engine, *args, **kwargs): - if not isinstance(database_engine, PostgresEngine): - return - - next_id = find_max_generated_user_id_localpart(cur) + 1 - cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql b/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql deleted file mode 100644 index cade5dcca8..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Recalculate the stats for all rooms after the fix to joined_members erroneously --- incrementing on per-room profile changes. - --- Note that the populate_stats_process_rooms background update is already set to --- run if you're upgrading from Synapse <1.0.0. - --- Additionally, if you've upgraded to v1.18.0 (which doesn't include this fix), --- this bg job runs, and then update to v1.19.0, you'd end up with only half of --- your rooms having room stats recalculated after this fix was in place. - --- So we've switched the old `populate_stats_process_rooms` background job to a --- no-op, and then kick off a bg job with a new name, but with the same --- functionality as the old one. This effectively restarts the background job --- from the beginning, without running it twice in a row, supporting both --- upgrade usecases. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c73..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. 
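-- A rough sketch of how a column like the one added just below could feed an
-- unread count: count events past the user's read receipt that are flagged as
-- countable. The room and stream position are made-up values; this is not the
-- query Synapse actually runs.
SELECT COUNT(*) AS unread_count
  FROM events
 WHERE room_id = '!foo:example.com'   -- example room
   AND stream_ordering > 12345        -- stream position of the user's read receipt
   AND count_as_unread;               -- NULL/false rows are not counted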
-ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql deleted file mode 100644 index 883fcd10b2..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create tables called application_services and - * application_services_regex, but these are no longer used and are removed in - * delta 54. - */ - - -CREATE TABLE IF NOT EXISTS application_services_state( - as_id TEXT PRIMARY KEY, - state VARCHAR(5), - last_txn INTEGER -); - -CREATE TABLE IF NOT EXISTS application_services_txns( - as_id TEXT NOT NULL, - txn_id INTEGER NOT NULL, - event_ids TEXT NOT NULL, - UNIQUE(as_id, txn_id) -); - -CREATE INDEX application_services_txns_id ON application_services_txns ( - as_id -); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql deleted file mode 100644 index 10ce2aa7a0..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create tables called event_destinations and - * state_forward_extremities, but these are no longer used and are removed in - * delta 54. - */ - -CREATE TABLE IF NOT EXISTS event_forward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); -CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_backward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); -CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_edges( - event_id TEXT NOT NULL, - prev_event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular - -- event dag edge. 
- UNIQUE (event_id, prev_event_id, room_id, is_state) -); - -CREATE INDEX ev_edges_id ON event_edges(event_id); -CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); - - -CREATE TABLE IF NOT EXISTS room_depth( - room_id TEXT NOT NULL, - min_depth INTEGER NOT NULL, - UNIQUE (room_id) -); - -CREATE INDEX room_depth_room ON room_depth(room_id); - -CREATE TABLE IF NOT EXISTS event_auth( - event_id TEXT NOT NULL, - auth_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, auth_id, room_id) -); - -CREATE INDEX evauth_edges_id ON event_auth(event_id); -CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql deleted file mode 100644 index 95826da431..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - /* We used to create tables called event_content_hashes and event_edge_hashes, - * but these are no longer used and are removed in delta 54. - */ - -CREATE TABLE IF NOT EXISTS event_reference_hashes ( - event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, algorithm) -); - -CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); - - -CREATE TABLE IF NOT EXISTS event_signatures ( - event_id TEXT, - signature_name TEXT, - key_id TEXT, - signature bytea, - UNIQUE (event_id, signature_name, key_id) -); - -CREATE INDEX event_signatures_id ON event_signatures(event_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql deleted file mode 100644 index a1a2aa8e5b..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create tables called room_hosts and feedback, - * but these are no longer used and are removed in delta 54. - */ - -CREATE TABLE IF NOT EXISTS events( - stream_ordering INTEGER PRIMARY KEY, - topological_ordering BIGINT NOT NULL, - event_id TEXT NOT NULL, - type TEXT NOT NULL, - room_id TEXT NOT NULL, - - -- 'content' used to be created NULLable, but as of delta 50 we drop that constraint. 
- -- the hack we use to drop the constraint doesn't work for an in-memory sqlite - -- database, which breaks the sytests. Hence, we no longer make it nullable. - content TEXT, - - unrecognized_keys TEXT, - processed BOOL NOT NULL, - outlier BOOL NOT NULL, - depth BIGINT DEFAULT 0 NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX events_stream_ordering ON events (stream_ordering); -CREATE INDEX events_topological_ordering ON events (topological_ordering); -CREATE INDEX events_order ON events (topological_ordering, stream_ordering); -CREATE INDEX events_room_id ON events (room_id); -CREATE INDEX events_order_room ON events ( - room_id, topological_ordering, stream_ordering -); - - -CREATE TABLE IF NOT EXISTS event_json( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - internal_metadata TEXT NOT NULL, - json TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX event_json_room_id ON event_json(room_id); - - -CREATE TABLE IF NOT EXISTS state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - prev_state TEXT, - UNIQUE (event_id) -); - -CREATE INDEX state_events_room_id ON state_events (room_id); -CREATE INDEX state_events_type ON state_events (type); -CREATE INDEX state_events_state_key ON state_events (state_key); - - -CREATE TABLE IF NOT EXISTS current_state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - UNIQUE (event_id), - UNIQUE (room_id, type, state_key) -); - -CREATE INDEX current_state_events_room_id ON current_state_events (room_id); -CREATE INDEX current_state_events_type ON current_state_events (type); -CREATE INDEX current_state_events_state_key ON current_state_events (state_key); - -CREATE TABLE IF NOT EXISTS room_memberships( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - sender TEXT NOT NULL, - room_id TEXT NOT NULL, - membership TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX room_memberships_user_id ON room_memberships (user_id); - -CREATE TABLE IF NOT EXISTS topics( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - topic TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX topics_room_id ON topics(room_id); - -CREATE TABLE IF NOT EXISTS room_names( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - name TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX room_names_room_id ON room_names(room_id); - -CREATE TABLE IF NOT EXISTS rooms( - room_id TEXT PRIMARY KEY NOT NULL, - is_public BOOL, - creator TEXT -); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql deleted file mode 100644 index 11cdffdbb3..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - --- we used to create a table called server_tls_certificates, but this is no --- longer used, and is removed in delta 54. - -CREATE TABLE IF NOT EXISTS server_signature_keys( - server_name TEXT, -- Server name. - key_id TEXT, -- Key version. - from_server TEXT, -- Which key server the key was fetched form. - ts_added_ms BIGINT, -- When the key was added. - verify_key bytea, -- NACL verification key. - UNIQUE (server_name, key_id) -); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql deleted file mode 100644 index 8f3759bb2a..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS local_media_repository ( - media_id TEXT, -- The id used to refer to the media. - media_type TEXT, -- The MIME-type of the media. - media_length INTEGER, -- Length of the media in bytes. - created_ts BIGINT, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - user_id TEXT, -- The user who uploaded the file. - UNIQUE (media_id) -); - -CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_method TEXT, -- The method used to make the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - UNIQUE ( - media_id, thumbnail_width, thumbnail_height, thumbnail_type - ) -); - -CREATE INDEX local_media_repository_thumbnails_media_id - ON local_media_repository_thumbnails (media_id); - -CREATE TABLE IF NOT EXISTS remote_media_cache ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media on that server. - media_type TEXT, -- The MIME-type of the media. - created_ts BIGINT, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - media_length INTEGER, -- Length of the media in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - UNIQUE (media_origin, media_id) -); - -CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_method TEXT, -- The method used to make the thumbnail - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. 
- UNIQUE ( - media_origin, media_id, thumbnail_width, thumbnail_height, - thumbnail_type - ) -); - -CREATE INDEX remote_media_cache_thumbnails_media_id - ON remote_media_cache_thumbnails (media_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql deleted file mode 100644 index 01d2d8f833..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS presence( - user_id TEXT NOT NULL, - state VARCHAR(20), - status_msg TEXT, - mtime BIGINT, -- miliseconds since last state change - UNIQUE (user_id) -); - --- For each of /my/ users which possibly-remote users are allowed to see their --- presence state -CREATE TABLE IF NOT EXISTS presence_allow_inbound( - observed_user_id TEXT NOT NULL, - observer_user_id TEXT NOT NULL, -- a UserID, - UNIQUE (observed_user_id, observer_user_id) -); - --- We used to create a table called presence_list, but this is no longer used --- and is removed in delta 54. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql deleted file mode 100644 index c04f4747d9..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS profiles( - user_id TEXT NOT NULL, - displayname TEXT, - avatar_url TEXT, - UNIQUE(user_id) -); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql deleted file mode 100644 index e44465cf45..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS rejections( - event_id TEXT NOT NULL, - reason TEXT NOT NULL, - last_check TEXT NOT NULL, - UNIQUE (event_id) -); - --- Push notification endpoints that users have configured -CREATE TABLE IF NOT EXISTS pushers ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey bytea NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data bytea, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey) -); - -CREATE TABLE IF NOT EXISTS push_rules ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - priority_class SMALLINT NOT NULL, - priority INTEGER NOT NULL DEFAULT 0, - conditions TEXT NOT NULL, - actions TEXT NOT NULL, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX push_rules_user_name on push_rules (user_name); - -CREATE TABLE IF NOT EXISTS user_filters( - user_id TEXT, - filter_id BIGINT, - filter_json bytea -); - -CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( - user_id, filter_id -); - -CREATE TABLE IF NOT EXISTS push_rules_enable ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - enabled SMALLINT, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql deleted file mode 100644 index 318f0d9aa5..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -CREATE TABLE IF NOT EXISTS redactions ( - event_id TEXT NOT NULL, - redacts TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX redactions_event_id ON redactions (event_id); -CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql deleted file mode 100644 index d47da3b12f..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS room_aliases( - room_alias TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (room_alias) -); - -CREATE INDEX room_aliases_id ON room_aliases(room_id); - -CREATE TABLE IF NOT EXISTS room_alias_servers( - room_alias TEXT NOT NULL, - server TEXT NOT NULL -); - -CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql deleted file mode 100644 index 96391a8f0e..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE IF NOT EXISTS state_groups( - id BIGINT PRIMARY KEY, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS state_groups_state( - state_group BIGINT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS event_to_state_groups( - event_id TEXT NOT NULL, - state_group BIGINT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX state_groups_id ON state_groups(id); - -CREATE INDEX state_groups_state_id ON state_groups_state(state_group); -CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); -CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql deleted file mode 100644 index 17e67bedac..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Stores what transaction ids we have received and what our response was -CREATE TABLE IF NOT EXISTS received_transactions( - transaction_id TEXT, - origin TEXT, - ts BIGINT, - response_code INTEGER, - response_json bytea, - has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx - UNIQUE (transaction_id, origin) -); - -CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; - --- For sent transactions only. -CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( - transaction_id INTEGER, - destination TEXT, - pdu_id TEXT, - pdu_origin TEXT, - UNIQUE (transaction_id, destination) -); - -CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); - --- To track destination health -CREATE TABLE IF NOT EXISTS destinations( - destination TEXT PRIMARY KEY, - retry_last_ts BIGINT, - retry_interval INTEGER -); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql deleted file mode 100644 index f013aa8b18..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -CREATE TABLE IF NOT EXISTS users( - name TEXT, - password_hash TEXT, - creation_ts BIGINT, - admin SMALLINT DEFAULT 0 NOT NULL, - UNIQUE(name) -); - -CREATE TABLE IF NOT EXISTS access_tokens( - id BIGINT PRIMARY KEY, - user_id TEXT NOT NULL, - device_id TEXT, - token TEXT NOT NULL, - last_used BIGINT, - UNIQUE(token) -); - -CREATE TABLE IF NOT EXISTS user_ips ( - user_id TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT, - ip TEXT NOT NULL, - user_agent TEXT NOT NULL, - last_seen BIGINT NOT NULL -); - -CREATE INDEX user_ips_user ON user_ips(user_id); -CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres deleted file mode 100644 index 889a9a0ce4..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ /dev/null @@ -1,1983 +0,0 @@ - - - - - -CREATE TABLE access_tokens ( - id bigint NOT NULL, - user_id text NOT NULL, - device_id text, - token text NOT NULL, - last_used bigint -); - - - -CREATE TABLE account_data ( - user_id text NOT NULL, - account_data_type text NOT NULL, - stream_id bigint NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE account_data_max_stream_id ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_id bigint NOT NULL, - CONSTRAINT private_user_data_max_stream_id_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE account_validity ( - user_id text NOT NULL, - expiration_ts_ms bigint NOT NULL, - email_sent boolean NOT NULL, - renewal_token text -); - - - -CREATE TABLE application_services_state ( - as_id text NOT NULL, - state character varying(5), - last_txn integer -); - - - -CREATE TABLE application_services_txns ( - as_id text NOT NULL, - txn_id integer NOT NULL, - event_ids text NOT NULL -); - - - -CREATE TABLE appservice_room_list ( - appservice_id text NOT NULL, - network_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE appservice_stream_position ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_ordering bigint, - CONSTRAINT appservice_stream_position_lock_check CHECK ((lock = 'X'::bpchar)) -); - - -CREATE TABLE blocked_rooms ( - room_id text NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE cache_invalidation_stream ( - stream_id bigint, - cache_func text, - keys text[], - invalidation_ts bigint -); - - - -CREATE TABLE current_state_delta_stream ( - stream_id bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text, - prev_event_id text -); - - - -CREATE TABLE current_state_events ( - event_id text NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL -); - - - -CREATE TABLE deleted_pushers ( - stream_id bigint NOT NULL, - app_id text NOT NULL, - pushkey text NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE destinations ( - destination text NOT NULL, - retry_last_ts bigint, - retry_interval integer -); - - - -CREATE TABLE device_federation_inbox ( - origin text NOT NULL, - message_id text NOT NULL, - received_ts bigint NOT NULL -); - - - -CREATE TABLE device_federation_outbox ( - destination text NOT NULL, - stream_id bigint NOT NULL, - queued_ts bigint NOT NULL, - messages_json text NOT NULL -); - - - -CREATE TABLE device_inbox ( - user_id text NOT NULL, - device_id text NOT NULL, - stream_id bigint NOT NULL, - message_json text NOT NULL -); - - - -CREATE TABLE 
device_lists_outbound_last_success ( - destination text NOT NULL, - user_id text NOT NULL, - stream_id bigint NOT NULL -); - - - -CREATE TABLE device_lists_outbound_pokes ( - destination text NOT NULL, - stream_id bigint NOT NULL, - user_id text NOT NULL, - device_id text NOT NULL, - sent boolean NOT NULL, - ts bigint NOT NULL -); - - - -CREATE TABLE device_lists_remote_cache ( - user_id text NOT NULL, - device_id text NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE device_lists_remote_extremeties ( - user_id text NOT NULL, - stream_id text NOT NULL -); - - - -CREATE TABLE device_lists_stream ( - stream_id bigint NOT NULL, - user_id text NOT NULL, - device_id text NOT NULL -); - - - -CREATE TABLE device_max_stream_id ( - stream_id bigint NOT NULL -); - - - -CREATE TABLE devices ( - user_id text NOT NULL, - device_id text NOT NULL, - display_name text -); - - - -CREATE TABLE e2e_device_keys_json ( - user_id text NOT NULL, - device_id text NOT NULL, - ts_added_ms bigint NOT NULL, - key_json text NOT NULL -); - - - -CREATE TABLE e2e_one_time_keys_json ( - user_id text NOT NULL, - device_id text NOT NULL, - algorithm text NOT NULL, - key_id text NOT NULL, - ts_added_ms bigint NOT NULL, - key_json text NOT NULL -); - - - -CREATE TABLE e2e_room_keys ( - user_id text NOT NULL, - room_id text NOT NULL, - session_id text NOT NULL, - version bigint NOT NULL, - first_message_index integer, - forwarded_count integer, - is_verified boolean, - session_data text NOT NULL -); - - - -CREATE TABLE e2e_room_keys_versions ( - user_id text NOT NULL, - version bigint NOT NULL, - algorithm text NOT NULL, - auth_data text NOT NULL, - deleted smallint DEFAULT 0 NOT NULL -); - - - -CREATE TABLE erased_users ( - user_id text NOT NULL -); - - - -CREATE TABLE event_auth ( - event_id text NOT NULL, - auth_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE event_backward_extremities ( - event_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE event_edges ( - event_id text NOT NULL, - prev_event_id text NOT NULL, - room_id text NOT NULL, - is_state boolean NOT NULL -); - - - -CREATE TABLE event_forward_extremities ( - event_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE event_json ( - event_id text NOT NULL, - room_id text NOT NULL, - internal_metadata text NOT NULL, - json text NOT NULL, - format_version integer -); - - - -CREATE TABLE event_push_actions ( - room_id text NOT NULL, - event_id text NOT NULL, - user_id text NOT NULL, - profile_tag character varying(32), - actions text NOT NULL, - topological_ordering bigint, - stream_ordering bigint, - notif smallint, - highlight smallint -); - - - -CREATE TABLE event_push_actions_staging ( - event_id text NOT NULL, - user_id text NOT NULL, - actions text NOT NULL, - notif smallint NOT NULL, - highlight smallint NOT NULL -); - - - -CREATE TABLE event_push_summary ( - user_id text NOT NULL, - room_id text NOT NULL, - notif_count bigint NOT NULL, - stream_ordering bigint NOT NULL -); - - - -CREATE TABLE event_push_summary_stream_ordering ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_ordering bigint NOT NULL, - CONSTRAINT event_push_summary_stream_ordering_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE event_reference_hashes ( - event_id text, - algorithm text, - hash bytea -); - - - -CREATE TABLE event_relations ( - event_id text NOT NULL, - relates_to_id text NOT NULL, - relation_type text NOT NULL, - aggregation_key text -); - - - -CREATE TABLE event_reports ( - id bigint NOT NULL, - 
received_ts bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL, - user_id text NOT NULL, - reason text, - content text -); - - - -CREATE TABLE event_search ( - event_id text, - room_id text, - sender text, - key text, - vector tsvector, - origin_server_ts bigint, - stream_ordering bigint -); - - - -CREATE TABLE event_to_state_groups ( - event_id text NOT NULL, - state_group bigint NOT NULL -); - - - -CREATE TABLE events ( - stream_ordering integer NOT NULL, - topological_ordering bigint NOT NULL, - event_id text NOT NULL, - type text NOT NULL, - room_id text NOT NULL, - content text, - unrecognized_keys text, - processed boolean NOT NULL, - outlier boolean NOT NULL, - depth bigint DEFAULT 0 NOT NULL, - origin_server_ts bigint, - received_ts bigint, - sender text, - contains_url boolean -); - - - -CREATE TABLE ex_outlier_stream ( - event_stream_ordering bigint NOT NULL, - event_id text NOT NULL, - state_group bigint NOT NULL -); - - - -CREATE TABLE federation_stream_position ( - type text NOT NULL, - stream_id integer NOT NULL -); - - - -CREATE TABLE group_attestations_remote ( - group_id text NOT NULL, - user_id text NOT NULL, - valid_until_ms bigint NOT NULL, - attestation_json text NOT NULL -); - - - -CREATE TABLE group_attestations_renewals ( - group_id text NOT NULL, - user_id text NOT NULL, - valid_until_ms bigint NOT NULL -); - - - -CREATE TABLE group_invites ( - group_id text NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE group_roles ( - group_id text NOT NULL, - role_id text NOT NULL, - profile text NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_room_categories ( - group_id text NOT NULL, - category_id text NOT NULL, - profile text NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_rooms ( - group_id text NOT NULL, - room_id text NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_summary_roles ( - group_id text NOT NULL, - role_id text NOT NULL, - role_order bigint NOT NULL, - CONSTRAINT group_summary_roles_role_order_check CHECK ((role_order > 0)) -); - - - -CREATE TABLE group_summary_room_categories ( - group_id text NOT NULL, - category_id text NOT NULL, - cat_order bigint NOT NULL, - CONSTRAINT group_summary_room_categories_cat_order_check CHECK ((cat_order > 0)) -); - - - -CREATE TABLE group_summary_rooms ( - group_id text NOT NULL, - room_id text NOT NULL, - category_id text NOT NULL, - room_order bigint NOT NULL, - is_public boolean NOT NULL, - CONSTRAINT group_summary_rooms_room_order_check CHECK ((room_order > 0)) -); - - - -CREATE TABLE group_summary_users ( - group_id text NOT NULL, - user_id text NOT NULL, - role_id text NOT NULL, - user_order bigint NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_users ( - group_id text NOT NULL, - user_id text NOT NULL, - is_admin boolean NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE groups ( - group_id text NOT NULL, - name text, - avatar_url text, - short_description text, - long_description text, - is_public boolean NOT NULL, - join_policy text DEFAULT 'invite'::text NOT NULL -); - - - -CREATE TABLE guest_access ( - event_id text NOT NULL, - room_id text NOT NULL, - guest_access text NOT NULL -); - - - -CREATE TABLE history_visibility ( - event_id text NOT NULL, - room_id text NOT NULL, - history_visibility text NOT NULL -); - - - -CREATE TABLE local_group_membership ( - group_id text NOT NULL, - user_id text NOT NULL, - is_admin boolean NOT NULL, - membership text NOT NULL, - is_publicised boolean NOT NULL, - 
content text NOT NULL -); - - - -CREATE TABLE local_group_updates ( - stream_id bigint NOT NULL, - group_id text NOT NULL, - user_id text NOT NULL, - type text NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE local_invites ( - stream_id bigint NOT NULL, - inviter text NOT NULL, - invitee text NOT NULL, - event_id text NOT NULL, - room_id text NOT NULL, - locally_rejected text, - replaced_by text -); - - - -CREATE TABLE local_media_repository ( - media_id text, - media_type text, - media_length integer, - created_ts bigint, - upload_name text, - user_id text, - quarantined_by text, - url_cache text, - last_access_ts bigint -); - - - -CREATE TABLE local_media_repository_thumbnails ( - media_id text, - thumbnail_width integer, - thumbnail_height integer, - thumbnail_type text, - thumbnail_method text, - thumbnail_length integer -); - - - -CREATE TABLE local_media_repository_url_cache ( - url text, - response_code integer, - etag text, - expires_ts bigint, - og text, - media_id text, - download_ts bigint -); - - - -CREATE TABLE monthly_active_users ( - user_id text NOT NULL, - "timestamp" bigint NOT NULL -); - - - -CREATE TABLE open_id_tokens ( - token text NOT NULL, - ts_valid_until_ms bigint NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE presence ( - user_id text NOT NULL, - state character varying(20), - status_msg text, - mtime bigint -); - - - -CREATE TABLE presence_allow_inbound ( - observed_user_id text NOT NULL, - observer_user_id text NOT NULL -); - - - -CREATE TABLE presence_stream ( - stream_id bigint, - user_id text, - state text, - last_active_ts bigint, - last_federation_update_ts bigint, - last_user_sync_ts bigint, - status_msg text, - currently_active boolean -); - - - -CREATE TABLE profiles ( - user_id text NOT NULL, - displayname text, - avatar_url text -); - - - -CREATE TABLE public_room_list_stream ( - stream_id bigint NOT NULL, - room_id text NOT NULL, - visibility boolean NOT NULL, - appservice_id text, - network_id text -); - - - -CREATE TABLE push_rules ( - id bigint NOT NULL, - user_name text NOT NULL, - rule_id text NOT NULL, - priority_class smallint NOT NULL, - priority integer DEFAULT 0 NOT NULL, - conditions text NOT NULL, - actions text NOT NULL -); - - - -CREATE TABLE push_rules_enable ( - id bigint NOT NULL, - user_name text NOT NULL, - rule_id text NOT NULL, - enabled smallint -); - - - -CREATE TABLE push_rules_stream ( - stream_id bigint NOT NULL, - event_stream_ordering bigint NOT NULL, - user_id text NOT NULL, - rule_id text NOT NULL, - op text NOT NULL, - priority_class smallint, - priority integer, - conditions text, - actions text -); - - - -CREATE TABLE pusher_throttle ( - pusher bigint NOT NULL, - room_id text NOT NULL, - last_sent_ts bigint, - throttle_ms bigint -); - - - -CREATE TABLE pushers ( - id bigint NOT NULL, - user_name text NOT NULL, - access_token bigint, - profile_tag text NOT NULL, - kind text NOT NULL, - app_id text NOT NULL, - app_display_name text NOT NULL, - device_display_name text NOT NULL, - pushkey text NOT NULL, - ts bigint NOT NULL, - lang text, - data text, - last_stream_ordering integer, - last_success bigint, - failing_since bigint -); - - - -CREATE TABLE ratelimit_override ( - user_id text NOT NULL, - messages_per_second bigint, - burst_count bigint -); - - - -CREATE TABLE receipts_graph ( - room_id text NOT NULL, - receipt_type text NOT NULL, - user_id text NOT NULL, - event_ids text NOT NULL, - data text NOT NULL -); - - - -CREATE TABLE receipts_linearized ( - stream_id bigint NOT NULL, - room_id text 
NOT NULL, - receipt_type text NOT NULL, - user_id text NOT NULL, - event_id text NOT NULL, - data text NOT NULL -); - - - -CREATE TABLE received_transactions ( - transaction_id text, - origin text, - ts bigint, - response_code integer, - response_json bytea, - has_been_referenced smallint DEFAULT 0 -); - - - -CREATE TABLE redactions ( - event_id text NOT NULL, - redacts text NOT NULL -); - - - -CREATE TABLE rejections ( - event_id text NOT NULL, - reason text NOT NULL, - last_check text NOT NULL -); - - - -CREATE TABLE remote_media_cache ( - media_origin text, - media_id text, - media_type text, - created_ts bigint, - upload_name text, - media_length integer, - filesystem_id text, - last_access_ts bigint, - quarantined_by text -); - - - -CREATE TABLE remote_media_cache_thumbnails ( - media_origin text, - media_id text, - thumbnail_width integer, - thumbnail_height integer, - thumbnail_method text, - thumbnail_type text, - thumbnail_length integer, - filesystem_id text -); - - - -CREATE TABLE remote_profile_cache ( - user_id text NOT NULL, - displayname text, - avatar_url text, - last_check bigint NOT NULL -); - - - -CREATE TABLE room_account_data ( - user_id text NOT NULL, - room_id text NOT NULL, - account_data_type text NOT NULL, - stream_id bigint NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE room_alias_servers ( - room_alias text NOT NULL, - server text NOT NULL -); - - - -CREATE TABLE room_aliases ( - room_alias text NOT NULL, - room_id text NOT NULL, - creator text -); - - - -CREATE TABLE room_depth ( - room_id text NOT NULL, - min_depth integer NOT NULL -); - - - -CREATE TABLE room_memberships ( - event_id text NOT NULL, - user_id text NOT NULL, - sender text NOT NULL, - room_id text NOT NULL, - membership text NOT NULL, - forgotten integer DEFAULT 0, - display_name text, - avatar_url text -); - - - -CREATE TABLE room_names ( - event_id text NOT NULL, - room_id text NOT NULL, - name text NOT NULL -); - - - -CREATE TABLE room_state ( - room_id text NOT NULL, - join_rules text, - history_visibility text, - encryption text, - name text, - topic text, - avatar text, - canonical_alias text -); - - - -CREATE TABLE room_stats ( - room_id text NOT NULL, - ts bigint NOT NULL, - bucket_size integer NOT NULL, - current_state_events integer NOT NULL, - joined_members integer NOT NULL, - invited_members integer NOT NULL, - left_members integer NOT NULL, - banned_members integer NOT NULL, - state_events integer NOT NULL -); - - - -CREATE TABLE room_stats_earliest_token ( - room_id text NOT NULL, - token bigint NOT NULL -); - - - -CREATE TABLE room_tags ( - user_id text NOT NULL, - room_id text NOT NULL, - tag text NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE room_tags_revisions ( - user_id text NOT NULL, - room_id text NOT NULL, - stream_id bigint NOT NULL -); - - - -CREATE TABLE rooms ( - room_id text NOT NULL, - is_public boolean, - creator text -); - - - -CREATE TABLE server_keys_json ( - server_name text NOT NULL, - key_id text NOT NULL, - from_server text NOT NULL, - ts_added_ms bigint NOT NULL, - ts_valid_until_ms bigint NOT NULL, - key_json bytea NOT NULL -); - - - -CREATE TABLE server_signature_keys ( - server_name text, - key_id text, - from_server text, - ts_added_ms bigint, - verify_key bytea, - ts_valid_until_ms bigint -); - - - -CREATE TABLE state_events ( - event_id text NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - prev_state text -); - - - -CREATE TABLE stats_stream_pos ( - lock character(1) DEFAULT 'X'::bpchar NOT 
NULL, - stream_id bigint, - CONSTRAINT stats_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE stream_ordering_to_exterm ( - stream_ordering bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE threepid_guest_access_tokens ( - medium text, - address text, - guest_access_token text, - first_inviter text -); - - - -CREATE TABLE topics ( - event_id text NOT NULL, - room_id text NOT NULL, - topic text NOT NULL -); - - - -CREATE TABLE user_daily_visits ( - user_id text NOT NULL, - device_id text, - "timestamp" bigint NOT NULL -); - - - -CREATE TABLE user_directory ( - user_id text NOT NULL, - room_id text, - display_name text, - avatar_url text -); - - - -CREATE TABLE user_directory_search ( - user_id text NOT NULL, - vector tsvector -); - - - -CREATE TABLE user_directory_stream_pos ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_id bigint, - CONSTRAINT user_directory_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE user_filters ( - user_id text, - filter_id bigint, - filter_json bytea -); - - - -CREATE TABLE user_ips ( - user_id text NOT NULL, - access_token text NOT NULL, - device_id text, - ip text NOT NULL, - user_agent text NOT NULL, - last_seen bigint NOT NULL -); - - - -CREATE TABLE user_stats ( - user_id text NOT NULL, - ts bigint NOT NULL, - bucket_size integer NOT NULL, - public_rooms integer NOT NULL, - private_rooms integer NOT NULL -); - - - -CREATE TABLE user_threepid_id_server ( - user_id text NOT NULL, - medium text NOT NULL, - address text NOT NULL, - id_server text NOT NULL -); - - - -CREATE TABLE user_threepids ( - user_id text NOT NULL, - medium text NOT NULL, - address text NOT NULL, - validated_at bigint NOT NULL, - added_at bigint NOT NULL -); - - - -CREATE TABLE users ( - name text, - password_hash text, - creation_ts bigint, - admin smallint DEFAULT 0 NOT NULL, - upgrade_ts bigint, - is_guest smallint DEFAULT 0 NOT NULL, - appservice_id text, - consent_version text, - consent_server_notice_sent text, - user_type text -); - - - -CREATE TABLE users_in_public_rooms ( - user_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE users_pending_deactivation ( - user_id text NOT NULL -); - - - -CREATE TABLE users_who_share_private_rooms ( - user_id text NOT NULL, - other_user_id text NOT NULL, - room_id text NOT NULL -); - - - -ALTER TABLE ONLY access_tokens - ADD CONSTRAINT access_tokens_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY access_tokens - ADD CONSTRAINT access_tokens_token_key UNIQUE (token); - - - -ALTER TABLE ONLY account_data - ADD CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type); - - - -ALTER TABLE ONLY account_validity - ADD CONSTRAINT account_validity_pkey PRIMARY KEY (user_id); - - - -ALTER TABLE ONLY application_services_state - ADD CONSTRAINT application_services_state_pkey PRIMARY KEY (as_id); - - - -ALTER TABLE ONLY application_services_txns - ADD CONSTRAINT application_services_txns_as_id_txn_id_key UNIQUE (as_id, txn_id); - - - -ALTER TABLE ONLY appservice_stream_position - ADD CONSTRAINT appservice_stream_position_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY current_state_events - ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY current_state_events - ADD CONSTRAINT current_state_events_room_id_type_state_key_key UNIQUE (room_id, type, state_key); - - - -ALTER TABLE ONLY destinations - ADD CONSTRAINT destinations_pkey PRIMARY KEY (destination); - - - -ALTER TABLE ONLY 
devices - ADD CONSTRAINT device_uniqueness UNIQUE (user_id, device_id); - - - -ALTER TABLE ONLY e2e_device_keys_json - ADD CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id); - - - -ALTER TABLE ONLY e2e_one_time_keys_json - ADD CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id); - - - -ALTER TABLE ONLY event_backward_extremities - ADD CONSTRAINT event_backward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); - - - -ALTER TABLE ONLY event_edges - ADD CONSTRAINT event_edges_event_id_prev_event_id_room_id_is_state_key UNIQUE (event_id, prev_event_id, room_id, is_state); - - - -ALTER TABLE ONLY event_forward_extremities - ADD CONSTRAINT event_forward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); - - - -ALTER TABLE ONLY event_push_actions - ADD CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag); - - - -ALTER TABLE ONLY event_json - ADD CONSTRAINT event_json_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY event_push_summary_stream_ordering - ADD CONSTRAINT event_push_summary_stream_ordering_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY event_reference_hashes - ADD CONSTRAINT event_reference_hashes_event_id_algorithm_key UNIQUE (event_id, algorithm); - - - -ALTER TABLE ONLY event_reports - ADD CONSTRAINT event_reports_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY event_to_state_groups - ADD CONSTRAINT event_to_state_groups_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY events - ADD CONSTRAINT events_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY events - ADD CONSTRAINT events_pkey PRIMARY KEY (stream_ordering); - - - -ALTER TABLE ONLY ex_outlier_stream - ADD CONSTRAINT ex_outlier_stream_pkey PRIMARY KEY (event_stream_ordering); - - - -ALTER TABLE ONLY group_roles - ADD CONSTRAINT group_roles_group_id_role_id_key UNIQUE (group_id, role_id); - - - -ALTER TABLE ONLY group_room_categories - ADD CONSTRAINT group_room_categories_group_id_category_id_key UNIQUE (group_id, category_id); - - - -ALTER TABLE ONLY group_summary_roles - ADD CONSTRAINT group_summary_roles_group_id_role_id_role_order_key UNIQUE (group_id, role_id, role_order); - - - -ALTER TABLE ONLY group_summary_room_categories - ADD CONSTRAINT group_summary_room_categories_group_id_category_id_cat_orde_key UNIQUE (group_id, category_id, cat_order); - - - -ALTER TABLE ONLY group_summary_rooms - ADD CONSTRAINT group_summary_rooms_group_id_category_id_room_id_room_order_key UNIQUE (group_id, category_id, room_id, room_order); - - - -ALTER TABLE ONLY guest_access - ADD CONSTRAINT guest_access_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY history_visibility - ADD CONSTRAINT history_visibility_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY local_media_repository - ADD CONSTRAINT local_media_repository_media_id_key UNIQUE (media_id); - - - -ALTER TABLE ONLY local_media_repository_thumbnails - ADD CONSTRAINT local_media_repository_thumbn_media_id_thumbnail_width_thum_key UNIQUE (media_id, thumbnail_width, thumbnail_height, thumbnail_type); - - - -ALTER TABLE ONLY user_threepids - ADD CONSTRAINT medium_address UNIQUE (medium, address); - - - -ALTER TABLE ONLY open_id_tokens - ADD CONSTRAINT open_id_tokens_pkey PRIMARY KEY (token); - - - -ALTER TABLE ONLY presence_allow_inbound - ADD CONSTRAINT presence_allow_inbound_observed_user_id_observer_user_id_key UNIQUE (observed_user_id, observer_user_id); - - - -ALTER TABLE ONLY presence - ADD CONSTRAINT presence_user_id_key 
UNIQUE (user_id); - - - -ALTER TABLE ONLY account_data_max_stream_id - ADD CONSTRAINT private_user_data_max_stream_id_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY profiles - ADD CONSTRAINT profiles_user_id_key UNIQUE (user_id); - - - -ALTER TABLE ONLY push_rules_enable - ADD CONSTRAINT push_rules_enable_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY push_rules_enable - ADD CONSTRAINT push_rules_enable_user_name_rule_id_key UNIQUE (user_name, rule_id); - - - -ALTER TABLE ONLY push_rules - ADD CONSTRAINT push_rules_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY push_rules - ADD CONSTRAINT push_rules_user_name_rule_id_key UNIQUE (user_name, rule_id); - - - -ALTER TABLE ONLY pusher_throttle - ADD CONSTRAINT pusher_throttle_pkey PRIMARY KEY (pusher, room_id); - - - -ALTER TABLE ONLY pushers - ADD CONSTRAINT pushers2_app_id_pushkey_user_name_key UNIQUE (app_id, pushkey, user_name); - - - -ALTER TABLE ONLY pushers - ADD CONSTRAINT pushers2_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY receipts_graph - ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id); - - - -ALTER TABLE ONLY receipts_linearized - ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id); - - - -ALTER TABLE ONLY received_transactions - ADD CONSTRAINT received_transactions_transaction_id_origin_key UNIQUE (transaction_id, origin); - - - -ALTER TABLE ONLY redactions - ADD CONSTRAINT redactions_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY rejections - ADD CONSTRAINT rejections_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY remote_media_cache - ADD CONSTRAINT remote_media_cache_media_origin_media_id_key UNIQUE (media_origin, media_id); - - - -ALTER TABLE ONLY remote_media_cache_thumbnails - ADD CONSTRAINT remote_media_cache_thumbnails_media_origin_media_id_thumbna_key UNIQUE (media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type); - - - -ALTER TABLE ONLY room_account_data - ADD CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type); - - - -ALTER TABLE ONLY room_aliases - ADD CONSTRAINT room_aliases_room_alias_key UNIQUE (room_alias); - - - -ALTER TABLE ONLY room_depth - ADD CONSTRAINT room_depth_room_id_key UNIQUE (room_id); - - - -ALTER TABLE ONLY room_memberships - ADD CONSTRAINT room_memberships_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY room_names - ADD CONSTRAINT room_names_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY room_tags_revisions - ADD CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id); - - - -ALTER TABLE ONLY room_tags - ADD CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag); - - - -ALTER TABLE ONLY rooms - ADD CONSTRAINT rooms_pkey PRIMARY KEY (room_id); - - - -ALTER TABLE ONLY server_keys_json - ADD CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server); - - - -ALTER TABLE ONLY server_signature_keys - ADD CONSTRAINT server_signature_keys_server_name_key_id_key UNIQUE (server_name, key_id); - - - -ALTER TABLE ONLY state_events - ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - - -ALTER TABLE ONLY stats_stream_pos - ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY topics - ADD CONSTRAINT topics_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY user_directory_stream_pos - ADD CONSTRAINT user_directory_stream_pos_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY users - ADD CONSTRAINT users_name_key UNIQUE (name); - - - -CREATE INDEX access_tokens_device_id ON 
access_tokens USING btree (user_id, device_id); - - - -CREATE INDEX account_data_stream_id ON account_data USING btree (user_id, stream_id); - - - -CREATE INDEX application_services_txns_id ON application_services_txns USING btree (as_id); - - - -CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list USING btree (appservice_id, network_id, room_id); - - - -CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms USING btree (room_id); - - - -CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream USING btree (stream_id); - - - -CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream USING btree (stream_id); - - - -CREATE INDEX current_state_events_member_index ON current_state_events USING btree (state_key) WHERE (type = 'm.room.member'::text); - - - -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers USING btree (stream_id); - - - -CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox USING btree (origin, message_id); - - - -CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox USING btree (destination, stream_id); - - - -CREATE INDEX device_federation_outbox_id ON device_federation_outbox USING btree (stream_id); - - - -CREATE INDEX device_inbox_stream_id_user_id ON device_inbox USING btree (stream_id, user_id); - - - -CREATE INDEX device_inbox_user_stream_id ON device_inbox USING btree (user_id, device_id, stream_id); - - - -CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success USING btree (destination, user_id, stream_id); - - - -CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes USING btree (destination, stream_id); - - - -CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes USING btree (stream_id); - - - -CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes USING btree (destination, user_id); - - - -CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache USING btree (user_id, device_id); - - - -CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties USING btree (user_id); - - - -CREATE INDEX device_lists_stream_id ON device_lists_stream USING btree (stream_id, user_id); - - - -CREATE INDEX device_lists_stream_user_id ON device_lists_stream USING btree (user_id, device_id); - - - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys USING btree (user_id, room_id, session_id); - - - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions USING btree (user_id, version); - - - -CREATE UNIQUE INDEX erased_users_user ON erased_users USING btree (user_id); - - - -CREATE INDEX ev_b_extrem_id ON event_backward_extremities USING btree (event_id); - - - -CREATE INDEX ev_b_extrem_room ON event_backward_extremities USING btree (room_id); - - - -CREATE INDEX ev_edges_id ON event_edges USING btree (event_id); - - - -CREATE INDEX ev_edges_prev_id ON event_edges USING btree (prev_event_id); - - - -CREATE INDEX ev_extrem_id ON event_forward_extremities USING btree (event_id); - - - -CREATE INDEX ev_extrem_room ON event_forward_extremities USING btree (room_id); - - - -CREATE INDEX evauth_edges_id ON event_auth USING btree (event_id); - - - -CREATE INDEX event_contains_url_index ON events USING btree (room_id, topological_ordering, stream_ordering) WHERE ((contains_url = true) AND (outlier = false)); - - - -CREATE INDEX event_json_room_id ON event_json USING btree (room_id); - - - -CREATE INDEX 
event_push_actions_highlights_index ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering) WHERE (highlight = 1); - - - -CREATE INDEX event_push_actions_rm_tokens ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering); - - - -CREATE INDEX event_push_actions_room_id_user_id ON event_push_actions USING btree (room_id, user_id); - - - -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging USING btree (event_id); - - - -CREATE INDEX event_push_actions_stream_ordering ON event_push_actions USING btree (stream_ordering, user_id); - - - -CREATE INDEX event_push_actions_u_highlight ON event_push_actions USING btree (user_id, stream_ordering); - - - -CREATE INDEX event_push_summary_user_rm ON event_push_summary USING btree (user_id, room_id); - - - -CREATE INDEX event_reference_hashes_id ON event_reference_hashes USING btree (event_id); - - - -CREATE UNIQUE INDEX event_relations_id ON event_relations USING btree (event_id); - - - -CREATE INDEX event_relations_relates ON event_relations USING btree (relates_to_id, relation_type, aggregation_key); - - - -CREATE INDEX event_search_ev_ridx ON event_search USING btree (room_id); - - - -CREATE UNIQUE INDEX event_search_event_id_idx ON event_search USING btree (event_id); - - - -CREATE INDEX event_search_fts_idx ON event_search USING gin (vector); - - - -CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups USING btree (state_group); - - - -CREATE INDEX events_order_room ON events USING btree (room_id, topological_ordering, stream_ordering); - - - -CREATE INDEX events_room_stream ON events USING btree (room_id, stream_ordering); - - - -CREATE INDEX events_ts ON events USING btree (origin_server_ts, stream_ordering); - - - -CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote USING btree (group_id, user_id); - - - -CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote USING btree (user_id); - - - -CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote USING btree (valid_until_ms); - - - -CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals USING btree (group_id, user_id); - - - -CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals USING btree (user_id); - - - -CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals USING btree (valid_until_ms); - - - -CREATE UNIQUE INDEX group_invites_g_idx ON group_invites USING btree (group_id, user_id); - - - -CREATE INDEX group_invites_u_idx ON group_invites USING btree (user_id); - - - -CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms USING btree (group_id, room_id); - - - -CREATE INDEX group_rooms_r_idx ON group_rooms USING btree (room_id); - - - -CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms USING btree (group_id, room_id, category_id); - - - -CREATE INDEX group_summary_users_g_idx ON group_summary_users USING btree (group_id); - - - -CREATE UNIQUE INDEX group_users_g_idx ON group_users USING btree (group_id, user_id); - - - -CREATE INDEX group_users_u_idx ON group_users USING btree (user_id); - - - -CREATE UNIQUE INDEX groups_idx ON groups USING btree (group_id); - - - -CREATE INDEX local_group_membership_g_idx ON local_group_membership USING btree (group_id); - - - -CREATE INDEX local_group_membership_u_idx ON local_group_membership USING btree (user_id, group_id); - - - -CREATE INDEX local_invites_for_user_idx ON local_invites USING btree (invitee, 
locally_rejected, replaced_by, room_id); - - - -CREATE INDEX local_invites_id ON local_invites USING btree (stream_id); - - - -CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails USING btree (media_id); - - - -CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache USING btree (url, download_ts); - - - -CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache USING btree (expires_ts); - - - -CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache USING btree (media_id); - - - -CREATE INDEX local_media_repository_url_idx ON local_media_repository USING btree (created_ts) WHERE (url_cache IS NOT NULL); - - - -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users USING btree ("timestamp"); - - - -CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users USING btree (user_id); - - - -CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens USING btree (ts_valid_until_ms); - - - -CREATE INDEX presence_stream_id ON presence_stream USING btree (stream_id, user_id); - - - -CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); - - - -CREATE INDEX public_room_index ON rooms USING btree (is_public); - - - -CREATE INDEX public_room_list_stream_idx ON public_room_list_stream USING btree (stream_id); - - - -CREATE INDEX public_room_list_stream_rm_idx ON public_room_list_stream USING btree (room_id, stream_id); - - - -CREATE INDEX push_rules_enable_user_name ON push_rules_enable USING btree (user_name); - - - -CREATE INDEX push_rules_stream_id ON push_rules_stream USING btree (stream_id); - - - -CREATE INDEX push_rules_stream_user_stream_id ON push_rules_stream USING btree (user_id, stream_id); - - - -CREATE INDEX push_rules_user_name ON push_rules USING btree (user_name); - - - -CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override USING btree (user_id); - - - -CREATE INDEX receipts_linearized_id ON receipts_linearized USING btree (stream_id); - - - -CREATE INDEX receipts_linearized_room_stream ON receipts_linearized USING btree (room_id, stream_id); - - - -CREATE INDEX receipts_linearized_user ON receipts_linearized USING btree (user_id); - - - -CREATE INDEX received_transactions_ts ON received_transactions USING btree (ts); - - - -CREATE INDEX redactions_redacts ON redactions USING btree (redacts); - - - -CREATE INDEX remote_profile_cache_time ON remote_profile_cache USING btree (last_check); - - - -CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache USING btree (user_id); - - - -CREATE INDEX room_account_data_stream_id ON room_account_data USING btree (user_id, stream_id); - - - -CREATE INDEX room_alias_servers_alias ON room_alias_servers USING btree (room_alias); - - - -CREATE INDEX room_aliases_id ON room_aliases USING btree (room_id); - - - -CREATE INDEX room_depth_room ON room_depth USING btree (room_id); - - - -CREATE INDEX room_memberships_room_id ON room_memberships USING btree (room_id); - - - -CREATE INDEX room_memberships_user_id ON room_memberships USING btree (user_id); - - - -CREATE INDEX room_names_room_id ON room_names USING btree (room_id); - - - -CREATE UNIQUE INDEX room_state_room ON room_state USING btree (room_id); - - - -CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token USING btree (room_id); - - - -CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); - - - -CREATE INDEX 
stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); - - - -CREATE INDEX stream_ordering_to_exterm_rm_idx ON stream_ordering_to_exterm USING btree (room_id, stream_ordering); - - - -CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens USING btree (medium, address); - - - -CREATE INDEX topics_room_id ON topics USING btree (room_id); - - - -CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits USING btree ("timestamp"); - - - -CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits USING btree (user_id, "timestamp"); - - - -CREATE INDEX user_directory_room_idx ON user_directory USING btree (room_id); - - - -CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin (vector); - - - -CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search USING btree (user_id); - - - -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory USING btree (user_id); - - - -CREATE INDEX user_filters_by_user_id_filter_id ON user_filters USING btree (user_id, filter_id); - - - -CREATE INDEX user_ips_device_id ON user_ips USING btree (user_id, device_id, last_seen); - - - -CREATE INDEX user_ips_last_seen ON user_ips USING btree (user_id, last_seen); - - - -CREATE INDEX user_ips_last_seen_only ON user_ips USING btree (last_seen); - - - -CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips USING btree (user_id, access_token, ip); - - - -CREATE UNIQUE INDEX user_stats_user_ts ON user_stats USING btree (user_id, ts); - - - -CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server USING btree (user_id, medium, address, id_server); - - - -CREATE INDEX user_threepids_medium_address ON user_threepids USING btree (medium, address); - - - -CREATE INDEX user_threepids_user_id ON user_threepids USING btree (user_id); - - - -CREATE INDEX users_creation_ts ON users USING btree (creation_ts); - - - -CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms USING btree (user_id, room_id); - - - -CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms USING btree (other_user_id); - - - -CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms USING btree (room_id); - - - -CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite deleted file mode 100644 index a0411ede7e..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ /dev/null @@ -1,253 +0,0 @@ -CREATE TABLE application_services_state( as_id TEXT PRIMARY KEY, state VARCHAR(5), last_txn INTEGER ); -CREATE TABLE application_services_txns( as_id TEXT NOT NULL, txn_id INTEGER NOT NULL, event_ids TEXT NOT NULL, UNIQUE(as_id, txn_id) ); -CREATE INDEX application_services_txns_id ON application_services_txns ( as_id ); -CREATE TABLE presence( user_id TEXT NOT NULL, state VARCHAR(20), status_msg TEXT, mtime BIGINT, UNIQUE (user_id) ); -CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_user_id TEXT NOT NULL, UNIQUE (observed_user_id, observer_user_id) ); -CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, 
consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); -CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); -CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); -CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); -CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); -CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); -CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); -CREATE INDEX events_order_room ON events ( room_id, topological_ordering, stream_ordering ); -CREATE TABLE event_json( event_id TEXT NOT NULL, room_id TEXT NOT NULL, internal_metadata TEXT NOT NULL, json TEXT NOT NULL, format_version INTEGER, UNIQUE (event_id) ); -CREATE INDEX event_json_room_id ON event_json(room_id); -CREATE TABLE state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, prev_state TEXT, UNIQUE (event_id) ); -CREATE TABLE current_state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, UNIQUE (event_id), UNIQUE (room_id, type, state_key) ); -CREATE TABLE room_memberships( event_id TEXT NOT NULL, user_id TEXT NOT NULL, sender TEXT NOT NULL, room_id TEXT NOT NULL, membership TEXT NOT NULL, forgotten INTEGER DEFAULT 0, display_name TEXT, avatar_url TEXT, UNIQUE (event_id) ); -CREATE INDEX room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX room_memberships_user_id ON room_memberships (user_id); -CREATE TABLE topics( event_id TEXT NOT NULL, room_id TEXT NOT NULL, topic TEXT NOT NULL, UNIQUE (event_id) ); -CREATE INDEX topics_room_id ON topics(room_id); -CREATE TABLE room_names( event_id TEXT NOT NULL, room_id TEXT NOT NULL, name TEXT NOT NULL, UNIQUE (event_id) ); -CREATE INDEX room_names_room_id ON room_names(room_id); -CREATE TABLE rooms( room_id TEXT PRIMARY KEY NOT NULL, is_public BOOL, creator TEXT ); -CREATE TABLE server_signature_keys( server_name TEXT, key_id TEXT, from_server TEXT, ts_added_ms BIGINT, verify_key bytea, ts_valid_until_ms BIGINT, UNIQUE (server_name, key_id) ); -CREATE TABLE rejections( event_id TEXT NOT NULL, reason TEXT NOT NULL, last_check TEXT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE push_rules ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, priority_class SMALLINT NOT NULL, priority INTEGER NOT NULL DEFAULT 0, conditions TEXT NOT NULL, actions TEXT NOT NULL, UNIQUE(user_name, rule_id) ); -CREATE INDEX push_rules_user_name on push_rules (user_name); -CREATE TABLE user_filters( user_id TEXT, filter_id BIGINT, filter_json bytea ); -CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( user_id, filter_id ); -CREATE TABLE push_rules_enable ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, enabled SMALLINT, 
UNIQUE(user_name, rule_id) ); -CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); -CREATE TABLE event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); -CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); -CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); -CREATE TABLE event_backward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); -CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); -CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); -CREATE TABLE event_edges( event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL, room_id TEXT NOT NULL, is_state BOOL NOT NULL, UNIQUE (event_id, prev_event_id, room_id, is_state) ); -CREATE INDEX ev_edges_id ON event_edges(event_id); -CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); -CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); -CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); -CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); -CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); -CREATE TABLE remote_media_cache ( media_origin TEXT, media_id TEXT, media_type TEXT, created_ts BIGINT, upload_name TEXT, media_length INTEGER, filesystem_id TEXT, last_access_ts BIGINT, quarantined_by TEXT, UNIQUE (media_origin, media_id) ); -CREATE TABLE remote_media_cache_thumbnails ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); -CREATE TABLE redactions ( event_id TEXT NOT NULL, redacts TEXT NOT NULL, UNIQUE (event_id) ); -CREATE INDEX redactions_redacts ON redactions (redacts); -CREATE TABLE room_aliases( room_alias TEXT NOT NULL, room_id TEXT NOT NULL, creator TEXT, UNIQUE (room_alias) ); -CREATE INDEX room_aliases_id ON room_aliases(room_id); -CREATE TABLE room_alias_servers( room_alias TEXT NOT NULL, server TEXT NOT NULL ); -CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); -CREATE TABLE event_reference_hashes ( event_id TEXT, algorithm TEXT, hash bytea, UNIQUE (event_id, algorithm) ); -CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); -CREATE TABLE IF NOT EXISTS "server_keys_json" ( server_name TEXT NOT NULL, key_id TEXT NOT NULL, from_server TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, ts_valid_until_ms BIGINT NOT NULL, key_json bytea NOT NULL, CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) ); -CREATE TABLE e2e_device_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) ); -CREATE TABLE e2e_one_time_keys_json ( user_id TEXT 
NOT NULL, device_id TEXT NOT NULL, algorithm TEXT NOT NULL, key_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); -CREATE TABLE receipts_graph( room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_ids TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) ); -CREATE TABLE receipts_linearized ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_id TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) ); -CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); -CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); -CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); -CREATE INDEX user_threepids_user_id ON user_threepids(user_id); -CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) -/* event_search(event_id,room_id,sender,"key",value) */; -CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); -CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB); -CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); -CREATE TABLE room_tags_revisions ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, stream_id BIGINT NOT NULL, CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) ); -CREATE TABLE IF NOT EXISTS "account_data_max_stream_id"( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT NOT NULL, CHECK (Lock='X') ); -CREATE TABLE account_data( user_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) ); -CREATE TABLE room_account_data( user_id TEXT NOT NULL, room_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) ); -CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); -CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); -CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering); -CREATE TABLE event_push_actions( room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, profile_tag VARCHAR(32), actions TEXT NOT NULL, topological_ordering BIGINT, 
stream_ordering BIGINT, notif SMALLINT, highlight SMALLINT, CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) ); -CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); -CREATE INDEX events_room_stream on events(room_id, stream_ordering); -CREATE INDEX public_room_index on rooms(is_public); -CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); -CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); -CREATE TABLE presence_stream( stream_id BIGINT, user_id TEXT, state TEXT, last_active_ts BIGINT, last_federation_update_ts BIGINT, last_user_sync_ts BIGINT, status_msg TEXT, currently_active BOOLEAN ); -CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); -CREATE INDEX presence_stream_user_id ON presence_stream(user_id); -CREATE TABLE push_rules_stream( stream_id BIGINT NOT NULL, event_stream_ordering BIGINT NOT NULL, user_id TEXT NOT NULL, rule_id TEXT NOT NULL, op TEXT NOT NULL, priority_class SMALLINT, priority INTEGER, conditions TEXT, actions TEXT ); -CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); -CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); -CREATE TABLE ex_outlier_stream( event_stream_ordering BIGINT PRIMARY KEY NOT NULL, event_id TEXT NOT NULL, state_group BIGINT NOT NULL ); -CREATE TABLE threepid_guest_access_tokens( medium TEXT, address TEXT, guest_access_token TEXT, first_inviter TEXT ); -CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); -CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, event_id TEXT NOT NULL, room_id TEXT NOT NULL, locally_rejected TEXT, replaced_by TEXT ); -CREATE INDEX local_invites_id ON local_invites(stream_id); -CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); -CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); -CREATE TABLE open_id_tokens ( token TEXT NOT NULL PRIMARY KEY, ts_valid_until_ms bigint NOT NULL, user_id TEXT NOT NULL, UNIQUE (token) ); -CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); -CREATE TABLE pusher_throttle( pusher BIGINT NOT NULL, room_id TEXT NOT NULL, last_sent_ts BIGINT, throttle_ms BIGINT, PRIMARY KEY (pusher, room_id) ); -CREATE TABLE event_reports( id BIGINT NOT NULL PRIMARY KEY, received_ts BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, reason TEXT, content TEXT ); -CREATE TABLE devices ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, display_name TEXT, CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) ); -CREATE TABLE appservice_stream_position( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT, CHECK (Lock='X') ); -CREATE TABLE device_inbox ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, stream_id BIGINT NOT NULL, message_json TEXT NOT NULL ); -CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); -CREATE INDEX received_transactions_ts ON received_transactions(ts); -CREATE TABLE device_federation_outbox ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, queued_ts BIGINT NOT NULL, messages_json TEXT NOT NULL ); -CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox(destination, stream_id); -CREATE TABLE 
device_federation_inbox ( origin TEXT NOT NULL, message_id TEXT NOT NULL, received_ts BIGINT NOT NULL ); -CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox(origin, message_id); -CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); -CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); -CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); -CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); -CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); -CREATE TABLE IF NOT EXISTS "event_auth"( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE INDEX evauth_edges_id ON event_auth(event_id); -CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); -CREATE TABLE appservice_room_list( appservice_id TEXT NOT NULL, network_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( appservice_id, network_id, room_id ); -CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); -CREATE TABLE federation_stream_position( type TEXT NOT NULL, stream_id INTEGER NOT NULL ); -CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, content TEXT NOT NULL ); -CREATE TABLE device_lists_remote_extremeties ( user_id TEXT NOT NULL, stream_id TEXT NOT NULL ); -CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL ); -CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); -CREATE TABLE device_lists_outbound_pokes ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, sent BOOLEAN NOT NULL, ts BIGINT NOT NULL ); -CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); -CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); -CREATE TABLE event_push_summary ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, notif_count BIGINT NOT NULL, stream_ordering BIGINT NOT NULL ); -CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); -CREATE TABLE event_push_summary_stream_ordering ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT NOT NULL, CHECK (Lock='X') ); -CREATE TABLE IF NOT EXISTS "pushers" ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, access_token BIGINT DEFAULT NULL, profile_tag TEXT NOT NULL, kind TEXT NOT NULL, app_id TEXT NOT NULL, app_display_name TEXT NOT NULL, device_display_name TEXT NOT NULL, pushkey TEXT NOT NULL, ts BIGINT NOT NULL, lang TEXT, data TEXT, last_stream_ordering INTEGER, last_success BIGINT, failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ); -CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); -CREATE TABLE ratelimit_override ( user_id TEXT NOT NULL, messages_per_second BIGINT, burst_count BIGINT ); -CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); -CREATE TABLE current_state_delta_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, 
type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT, prev_event_id TEXT ); -CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); -CREATE TABLE device_lists_outbound_last_success ( destination TEXT NOT NULL, user_id TEXT NOT NULL, stream_id BIGINT NOT NULL ); -CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( destination, user_id, stream_id ); -CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); -CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) -/* user_directory_search(user_id,value) */; -CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value'); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB); -CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); -CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); -CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); -CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); -CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); -CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); -CREATE TABLE group_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, is_public BOOLEAN NOT NULL ); -CREATE TABLE group_invites ( group_id TEXT NOT NULL, user_id TEXT NOT NULL ); -CREATE TABLE group_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, is_public BOOLEAN NOT NULL ); -CREATE TABLE group_summary_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, category_id TEXT NOT NULL, room_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id, room_id, room_order), CHECK (room_order > 0) ); -CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); -CREATE TABLE group_summary_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, cat_order BIGINT NOT NULL, UNIQUE (group_id, category_id, cat_order), CHECK (cat_order > 0) ); -CREATE TABLE group_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id) ); -CREATE TABLE group_summary_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, role_id TEXT NOT NULL, user_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL ); -CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); -CREATE TABLE group_summary_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, role_order BIGINT NOT NULL, UNIQUE (group_id, role_id, role_order), CHECK (role_order > 0) ); -CREATE TABLE group_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, role_id) ); -CREATE TABLE group_attestations_renewals ( 
group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL ); -CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); -CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); -CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); -CREATE TABLE group_attestations_remote ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL, attestation_json TEXT NOT NULL ); -CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); -CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); -CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); -CREATE TABLE local_group_membership ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, membership TEXT NOT NULL, is_publicised BOOLEAN NOT NULL, content TEXT NOT NULL ); -CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); -CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); -CREATE TABLE local_group_updates ( stream_id BIGINT NOT NULL, group_id TEXT NOT NULL, user_id TEXT NOT NULL, type TEXT NOT NULL, content TEXT NOT NULL ); -CREATE TABLE remote_profile_cache ( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, last_check BIGINT NOT NULL ); -CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); -CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); -CREATE TABLE IF NOT EXISTS "deleted_pushers" ( stream_id BIGINT NOT NULL, app_id TEXT NOT NULL, pushkey TEXT NOT NULL, user_id TEXT NOT NULL ); -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); -CREATE TABLE IF NOT EXISTS "groups" ( group_id TEXT NOT NULL, name TEXT, avatar_url TEXT, short_description TEXT, long_description TEXT, is_public BOOL NOT NULL , join_policy TEXT NOT NULL DEFAULT 'invite'); -CREATE UNIQUE INDEX groups_idx ON groups(group_id); -CREATE TABLE IF NOT EXISTS "user_directory" ( user_id TEXT NOT NULL, room_id TEXT, display_name TEXT, avatar_url TEXT ); -CREATE INDEX user_directory_room_idx ON user_directory(room_id); -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); -CREATE TABLE event_push_actions_staging ( event_id TEXT NOT NULL, user_id TEXT NOT NULL, actions TEXT NOT NULL, notif SMALLINT NOT NULL, highlight SMALLINT NOT NULL ); -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); -CREATE TABLE users_pending_deactivation ( user_id TEXT NOT NULL ); -CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); -CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); -CREATE INDEX group_users_u_idx ON group_users(user_id); -CREATE INDEX group_invites_u_idx ON group_invites(user_id); -CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); -CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); -CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); -CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); -CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); -CREATE TABLE erased_users ( user_id TEXT NOT NULL ); -CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); -CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, timestamp BIGINT NOT NULL ); -CREATE UNIQUE INDEX 
monthly_active_users_users ON monthly_active_users(user_id); -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); -CREATE TABLE IF NOT EXISTS "e2e_room_keys_versions" ( user_id TEXT NOT NULL, version BIGINT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); -CREATE TABLE IF NOT EXISTS "e2e_room_keys" ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, version BIGINT NOT NULL, first_message_index INT, forwarded_count INT, is_verified BOOLEAN, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); -CREATE TABLE users_who_share_private_rooms ( user_id TEXT NOT NULL, other_user_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); -CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); -CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); -CREATE TABLE user_threepid_id_server ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, id_server TEXT NOT NULL ); -CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( user_id, medium, address, id_server ); -CREATE TABLE users_in_public_rooms ( user_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); -CREATE TABLE account_validity ( user_id TEXT PRIMARY KEY, expiration_ts_ms BIGINT NOT NULL, email_sent BOOLEAN NOT NULL, renewal_token TEXT ); -CREATE TABLE event_relations ( event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL, relation_type TEXT NOT NULL, aggregation_key TEXT ); -CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); -CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); -CREATE TABLE stats_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); -CREATE TABLE user_stats ( user_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, public_rooms INT NOT NULL, private_rooms INT NOT NULL ); -CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); -CREATE TABLE room_stats ( room_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, current_state_events INT NOT NULL, joined_members INT NOT NULL, invited_members INT NOT NULL, left_members INT NOT NULL, banned_members INT NOT NULL, state_events INT NOT NULL ); -CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); -CREATE TABLE room_state ( room_id TEXT NOT NULL, join_rules TEXT, history_visibility TEXT, encryption TEXT, name TEXT, topic TEXT, avatar TEXT, canonical_alias TEXT ); -CREATE UNIQUE INDEX room_state_room ON room_state(room_id); -CREATE TABLE room_stats_earliest_token ( room_id TEXT NOT NULL, token BIGINT NOT NULL ); -CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); -CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); -CREATE INDEX user_ips_device_id ON user_ips (user_id, device_id, last_seen); -CREATE INDEX event_contains_url_index ON events (room_id, topological_ordering, stream_ordering); -CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering); -CREATE INDEX 
event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering);
-CREATE INDEX current_state_events_member_index ON current_state_events (state_key);
-CREATE INDEX device_inbox_stream_id_user_id ON device_inbox (stream_id, user_id);
-CREATE INDEX device_lists_stream_user_id ON device_lists_stream (user_id, device_id);
-CREATE INDEX local_media_repository_url_idx ON local_media_repository (created_ts);
-CREATE INDEX user_ips_last_seen ON user_ips (user_id, last_seen);
-CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
-CREATE INDEX users_creation_ts ON users (creation_ts);
-CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
-CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
-CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql deleted file mode 100644 index 91d21b2921..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql +++ /dev/null @@ -1,8 +0,0 @@
-
-INSERT INTO appservice_stream_position (stream_ordering) SELECT COALESCE(MAX(stream_ordering), 0) FROM events;
-INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1);
-INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events;
-INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
-INSERT INTO stats_stream_pos (stream_id) VALUES (0);
-INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
--- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md deleted file mode 100644 index c00f287190..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.md +++ /dev/null @@ -1,21 +0,0 @@
-# Synapse Database Schemas
-
-These schemas are used as a basis to create brand new Synapse databases, on both
-SQLite3 and Postgres.
-
-## Building full schema dumps
-
-If you want to recreate these schemas, they need to be made from a database that
-has had all background updates run.
-
-To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
-`full.sql.postgres ` and `full.sql.sqlite` files.
-
-Ensure postgres is installed and your user has the ability to run bash commands
-such as `createdb`, then call
-
-    ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
-
-There are currently two folders with full-schema snapshots. `16` is a snapshot
-from 2015, for historical reference. The other contains the most recent full
-schema snapshot.
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py deleted file mode 100644 index d52228297c..0000000000 --- a/synapse/storage/data_stores/main/search.py +++ /dev/null @@ -1,708 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-from collections import namedtuple
-
-from twisted.internet import defer
-
-from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.storage.database import Database
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-
-logger = logging.getLogger(__name__)
-
-SearchEntry = namedtuple(
-    "SearchEntry",
-    ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
-)
-
-
-class SearchWorkerStore(SQLBaseStore):
-    def store_search_entries_txn(self, txn, entries):
-        """Add entries to the search table
-
-        Args:
-            txn (cursor):
-            entries (iterable[SearchEntry]):
-                entries to be added to the table
-        """
-        if not self.hs.config.enable_search:
-            return
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-
-            args = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    entry.value,
-                    entry.stream_ordering,
-                    entry.origin_server_ts,
-                )
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args = (
-                (entry.event_id, entry.room_id, entry.key, entry.value)
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
-
-class SearchBackgroundUpdateStore(SearchWorkerStore):
-
-    EVENT_SEARCH_UPDATE_NAME = "event_search"
-    EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
-    EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
-    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
-
-    def __init__(self, database: Database, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
-
-        if not hs.config.enable_search:
-            return
-
-        self.db.updates.register_background_update_handler(
-            self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
-        )
-        self.db.updates.register_background_update_handler(
-            self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
-        )
-
-        # we used to have a background update to turn the GIN index into a
-        # GIST one; we no longer do that (obviously) because we actually want
-        # a GIN index. However, it's possible that some people might still have
-        # the background update queued, so we register a handler to clear the
-        # background update.
- self.db.updates.register_noop_background_update( - self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME - ) - - self.db.updates.register_background_update_handler( - self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search - ) - - @defer.inlineCallbacks - def _background_reindex_search(self, progress, batch_size): - # we work through the events table from highest stream id to lowest - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - TYPES = ["m.room.name", "m.room.message", "m.room.topic"] - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, room_id, type, json, " - " origin_server_ts FROM events" - " JOIN event_json USING (room_id, event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " AND (%s)" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - # we could stream straight from the results into - # store_search_entries_txn with a generator function, but that - # would mean having two cursors open on the database at once. - # Instead we just build a list of results. - rows = self.db.cursor_to_dict(txn) - if not rows: - return 0 - - min_stream_id = rows[-1]["stream_ordering"] - - event_search_rows = [] - for row in rows: - try: - event_id = row["event_id"] - room_id = row["room_id"] - etype = row["type"] - stream_ordering = row["stream_ordering"] - origin_server_ts = row["origin_server_ts"] - try: - event_json = db_to_json(row["json"]) - content = event_json["content"] - except Exception: - continue - - if etype == "m.room.message": - key = "content.body" - value = content["body"] - elif etype == "m.room.topic": - key = "content.topic" - value = content["topic"] - elif etype == "m.room.name": - key = "content.name" - value = content["name"] - else: - raise Exception("unexpected event type %s" % etype) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - if not isinstance(value, str): - # If the event body, name or topic isn't a string - # then skip over it - continue - - event_search_rows.append( - SearchEntry( - key=key, - value=value, - event_id=event_id, - room_id=room_id, - stream_ordering=stream_ordering, - origin_server_ts=origin_server_ts, - ) - ) - - self.store_search_entries_txn(txn, event_search_rows) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(event_search_rows), - } - - self.db.updates._background_update_progress_txn( - txn, self.EVENT_SEARCH_UPDATE_NAME, progress - ) - - return len(event_search_rows) - - result = yield self.db.runInteraction( - self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn - ) - - if not result: - yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) - - return result - - @defer.inlineCallbacks - def _background_reindex_gin_search(self, progress, batch_size): - """This handles old synapses which used GIST indexes, if any; - converting them back to be GIN as per the actual schema. - """ - - def create_index(conn): - conn.rollback() - - # we have to set autocommit, because postgres refuses to - # CREATE INDEX CONCURRENTLY without it. 
- conn.set_session(autocommit=True) - - try: - c = conn.cursor() - - # if we skipped the conversion to GIST, we may already/still - # have an event_search_fts_idx; unfortunately postgres 9.4 - # doesn't support CREATE INDEX IF EXISTS so we just catch the - # exception and ignore it. - import psycopg2 - - try: - c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx" - " ON event_search USING GIN (vector)" - ) - except psycopg2.ProgrammingError as e: - logger.warning( - "Ignoring error %r when trying to switch from GIST to GIN", e - ) - - # we should now be able to delete the GIST index. - c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist") - finally: - conn.set_session(autocommit=False) - - if isinstance(self.database_engine, PostgresEngine): - yield self.db.runWithConnection(create_index) - - yield self.db.updates._end_background_update( - self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME - ) - return 1 - - @defer.inlineCallbacks - def _background_reindex_search_order(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - have_added_index = progress["have_added_indexes"] - - if not have_added_index: - - def create_index(conn): - conn.rollback() - conn.set_session(autocommit=True) - c = conn.cursor() - - # We create with NULLS FIRST so that when we search *backwards* - # we get the ones with non null origin_server_ts *first* - c.execute( - "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" - "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" - ) - c.execute( - "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" - "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" - ) - conn.set_session(autocommit=False) - - yield self.db.runWithConnection(create_index) - - pg = dict(progress) - pg["have_added_indexes"] = True - - yield self.db.runInteraction( - self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self.db.updates._background_update_progress_txn, - self.EVENT_SEARCH_ORDER_UPDATE_NAME, - pg, - ) - - def reindex_search_txn(txn): - sql = ( - "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," - " origin_server_ts = e.origin_server_ts" - " FROM events AS e" - " WHERE e.event_id = es.event_id" - " AND ? <= e.stream_ordering AND e.stream_ordering < ?" - " RETURNING es.stream_ordering" - ) - - min_stream_id = max_stream_id - batch_size - txn.execute(sql, (min_stream_id, max_stream_id)) - rows = txn.fetchall() - - if min_stream_id < target_min_stream_id: - # We've recached the end. - return len(rows), False - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - "have_added_indexes": True, - } - - self.db.updates._background_update_progress_txn( - txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress - ) - - return len(rows), True - - num_rows, finished = yield self.db.runInteraction( - self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn - ) - - if not finished: - yield self.db.updates._end_background_update( - self.EVENT_SEARCH_ORDER_UPDATE_NAME - ) - - return num_rows - - -class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): - super(SearchStore, self).__init__(database, db_conn, hs) - - @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): - """Performs a full text search over events with given keys. 
- - Args: - room_ids (list): List of room ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - - Returns: - list of dicts - """ - clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - - args = [] - - # Make sure we don't explode because the person is in too many rooms. - # We filter the results below regardless. - if len(room_ids) < 500: - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - clauses = [clause] - - local_clauses = [] - for key in keys: - local_clauses.append("key = ?") - args.append(key) - - clauses.append("(%s)" % (" OR ".join(local_clauses),)) - - count_args = args - count_clauses = clauses - - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) - args = [search_query, search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) - count_args = [search_query] + count_args - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) - args = [search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) - count_args = [search_term] + count_args - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - for clause in clauses: - sql += " AND " + clause - - for clause in count_clauses: - count_sql += " AND " + clause - - # We add an arbitrary limit here to ensure we don't try to pull the - # entire table from the database. - sql += " ORDER BY rank DESC LIMIT 500" - - results = yield self.db.execute( - "search_msgs", self.db.cursor_to_dict, sql, *args - ) - - results = list(filter(lambda row: row["room_id"] in room_ids, results)) - - # We set redact_behaviour to BLOCK here to prevent redacted events being returned in - # search results (which is a data leak) - events = yield self.get_events_as_list( - [r["event_id"] for r in results], - redact_behaviour=EventRedactBehaviour.BLOCK, - ) - - event_map = {ev.event_id: ev for ev in events} - - highlights = None - if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) - - count_sql += " GROUP BY room_id" - - count_results = yield self.db.execute( - "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args - ) - - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - - return { - "results": [ - {"event": event_map[r["event_id"]], "rank": r["rank"]} - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - } - - @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): - """Performs a full text search over events with given keys. 
- - Args: - room_id (list): The room_ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - pagination_token (str): A pagination token previously returned - - Returns: - list of dicts - """ - clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - - args = [] - - # Make sure we don't explode because the person is in too many rooms. - # We filter the results below regardless. - if len(room_ids) < 500: - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - clauses = [clause] - - local_clauses = [] - for key in keys: - local_clauses.append("key = ?") - args.append(key) - - clauses.append("(%s)" % (" OR ".join(local_clauses),)) - - # take copies of the current args and clauses lists, before adding - # pagination clauses to main query. - count_args = list(args) - count_clauses = list(clauses) - - if pagination_token: - try: - origin_server_ts, stream = pagination_token.split(",") - origin_server_ts = int(origin_server_ts) - stream = int(stream) - except Exception: - raise SynapseError(400, "Invalid pagination token") - - clauses.append( - "(origin_server_ts < ?" - " OR (origin_server_ts = ? AND stream_ordering < ?))" - ) - args.extend([origin_server_ts, origin_server_ts, stream]) - - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) - args = [search_query, search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) - count_args = [search_query] + count_args - elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. - # https://sqlite.org/optoverview.html#crossjoin - # - # We want to use the full text search index on event_search to - # extract all possible matches first, then lookup those matches - # in the events table to get the topological ordering. We need - # to use the indexes in this order because sqlite refuses to - # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " - ) - args = [search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) - count_args = [search_term] + count_args - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - sql += " AND ".join(clauses) - count_sql += " AND ".join(count_clauses) - - # We add an arbitrary limit here to ensure we don't try to pull the - # entire table from the database. - if isinstance(self.database_engine, PostgresEngine): - sql += ( - " ORDER BY origin_server_ts DESC NULLS LAST," - " stream_ordering DESC NULLS LAST LIMIT ?" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" 
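The pagination_token handling above is keyset pagination: the token packs the two ordering columns as "origin_server_ts,stream_ordering", and the extra WHERE clause keeps only rows strictly before that pair when walking backwards. A plain-Python sketch of the same comparison, run against made-up rows rather than the event_search table:

    def parse_token(token):
        """Split "ts,stream" into two ints, as the code above does."""
        origin_server_ts, stream = token.split(",")
        return int(origin_server_ts), int(stream)

    def page_before(rows, token, limit):
        """rows: iterable of (origin_server_ts, stream_ordering, event_id)."""
        ts, stream = parse_token(token)
        older = [
            r for r in rows
            if r[0] < ts or (r[0] == ts and r[1] < stream)
        ]
        # newest first, mirroring ORDER BY origin_server_ts DESC, stream_ordering DESC
        older.sort(key=lambda r: (r[0], r[1]), reverse=True)
        return older[:limit]

    rows = [(100, 1, "$a"), (100, 2, "$b"), (200, 3, "$c")]
    print(page_before(rows, "200,3", limit=10))
    # [(100, 2, '$b'), (100, 1, '$a')]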
- else: - raise Exception("Unrecognized database engine") - - args.append(limit) - - results = yield self.db.execute( - "search_rooms", self.db.cursor_to_dict, sql, *args - ) - - results = list(filter(lambda row: row["room_id"] in room_ids, results)) - - # We set redact_behaviour to BLOCK here to prevent redacted events being returned in - # search results (which is a data leak) - events = yield self.get_events_as_list( - [r["event_id"] for r in results], - redact_behaviour=EventRedactBehaviour.BLOCK, - ) - - event_map = {ev.event_id: ev for ev in events} - - highlights = None - if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) - - count_sql += " GROUP BY room_id" - - count_results = yield self.db.execute( - "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args - ) - - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - - return { - "results": [ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" - % (r["origin_server_ts"], r["stream_ordering"]), - } - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - } - - def _find_highlights_in_postgres(self, search_query, events): - """Given a list of events and a search term, return a list of words - that match from the content of the event. - - This is used to give a list of words that clients can match against to - highlight the matching parts. - - Args: - search_query (str) - events (list): A list of events - - Returns: - deferred : A set of strings. - """ - - def f(txn): - highlight_words = set() - for event in events: - # As a hack we simply join values of all possible keys. This is - # fine since we're only using them to find possible highlights. - values = [] - for key in ("body", "name", "topic"): - v = event.content.get(key, None) - if v: - values.append(v) - - if not values: - continue - - value = " ".join(values) - - # We need to find some values for StartSel and StopSel that - # aren't in the value so that we can pick results out. - start_sel = "<" - stop_sel = ">" - - while start_sel in value: - start_sel += "<" - while stop_sel in value: - stop_sel += ">" - - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( - _to_postgres_options( - { - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - } - ) - ) - txn.execute(query, (value, search_query)) - (headline,) = txn.fetchall()[0] - - # Now we need to pick the possible highlights out of the haedline - # result. - matcher_regex = "%s(.*?)%s" % ( - re.escape(start_sel), - re.escape(stop_sel), - ) - - res = re.findall(matcher_regex, headline) - highlight_words.update([r.lower() for r in res]) - - return highlight_words - - return self.db.runInteraction("_find_highlights", f) - - -def _to_postgres_options(options_dict): - return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) - - -def _parse_query(database_engine, search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. - """ - - # Pull out the individual words, discarding any non-word characters. 
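The highlight extraction above works by choosing StartSel/StopSel markers that cannot occur in the text, asking ts_headline() to wrap matches in them, and then pulling the wrapped fragments back out with a regex. A standalone sketch of just the Python half, with a hard-coded string standing in for the ts_headline() result:

    import re

    def pick_delimiters(value):
        """Grow the markers until neither appears in the text being searched."""
        start_sel, stop_sel = "<", ">"
        while start_sel in value:
            start_sel += "<"
        while stop_sel in value:
            stop_sel += ">"
        return start_sel, stop_sel

    value = "the <quick> brown fox"
    start_sel, stop_sel = pick_delimiters(value)   # "<<" and ">>" for this value

    # Pretend ts_headline() returned this with the delimiters chosen above.
    headline = "the <quick> brown <<Fox>>"

    matcher = "%s(.*?)%s" % (re.escape(start_sel), re.escape(stop_sel))
    highlights = {m.lower() for m in re.findall(matcher, headline)}
    print(highlights)  # {'fox'}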
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py deleted file mode 100644 index 36244d9f5d..0000000000 --- a/synapse/storage/data_stores/main/signatures.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unpaddedbase64 import encode_base64 - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList - - -class SignatureWorkerStore(SQLBaseStore): - @cached() - def get_event_reference_hash(self, event_id): - # This is a dummy function to allow get_event_reference_hashes - # to use its cache - raise NotImplementedError() - - @cachedList( - cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 - ) - def get_event_reference_hashes(self, event_ids): - def f(txn): - return { - event_id: self._get_event_reference_hashes_txn(txn, event_id) - for event_id in event_ids - } - - return self.db.runInteraction("get_event_reference_hashes", f) - - @defer.inlineCallbacks - def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes(event_ids) - hashes = { - e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} - for e_id, h in hashes.items() - } - - return list(hashes.items()) - - def _get_event_reference_hashes_txn(self, txn, event_id): - """Get all the hashes for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict[unicode, bytes] of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_reference_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id,)) - return {k: v for k, v in txn} - - -class SignatureStore(SignatureWorkerStore): - """Persistence for event signatures and hashes""" diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py deleted file mode 100644 index a360699408..0000000000 --- a/synapse/storage/data_stores/main/state.py +++ /dev/null @@ -1,509 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections.abc -import logging -from collections import namedtuple -from typing import Iterable, Optional, Set - -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion -from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.database import Database -from synapse.storage.state import StateFilter -from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedList - -logger = logging.getLogger(__name__) - - -MAX_STATE_DELTA_HOPS = 100 - - -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - -# this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): - """The parts of StateGroupStore that can be called from workers. - """ - - def __init__(self, database: Database, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - - async def get_room_version(self, room_id: str) -> RoomVersion: - """Get the room_version of a given room - - Raises: - NotFoundError: if the room is unknown - - UnsupportedRoomVersionError: if the room uses an unknown room version. - Typically this happens if support for the room's version has been - removed from Synapse. - """ - room_version_id = await self.get_room_version_id(room_id) - v = KNOWN_ROOM_VERSIONS.get(room_version_id) - - if not v: - raise UnsupportedRoomVersionError( - "Room %s uses a room version %s which is no longer supported" - % (room_id, room_version_id) - ) - - return v - - @cached(max_entries=10000) - async def get_room_version_id(self, room_id: str) -> str: - """Get the room_version of a given room - - Raises: - NotFoundError: if the room is unknown - """ - - # First we try looking up room version from the database, but for old - # rooms we might not have added the room version to it yet so we fall - # back to previous behaviour and look in current state events. - - # We really should have an entry in the rooms table for every room we - # care about, but let's be a bit paranoid (at least while the background - # update is happening) to avoid breaking existing rooms. 
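_GetStateGroupDelta above defines __len__ on a namedtuple purely so that the caching layer, when told the value is iterable, can weigh a cached entry by how many delta IDs it carries. A small sketch of that pattern on its own, with the cache itself left out:

    from collections import namedtuple

    class _Delta(namedtuple("_Delta", ("prev_group", "delta_ids"))):
        """A result object whose "size" is the number of deltas it holds."""

        __slots__ = []

        def __len__(self):
            return len(self.delta_ids) if self.delta_ids else 0

    empty = _Delta(prev_group=42, delta_ids=None)
    populated = _Delta(prev_group=42, delta_ids={("m.room.member", "@a:hs"): "$ev"})

    # A size-aware cache can call len() to charge bigger entries more.
    print(len(empty), len(populated))  # 0 1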
- version = await self.db.simple_select_one_onecol( - table="rooms", - keyvalues={"room_id": room_id}, - retcol="room_version", - desc="get_room_version", - allow_none=True, - ) - - if version is not None: - return version - - # Retrieve the room's create event - create_event = await self.get_create_event_for_room(room_id) - return create_event.content.get("room_version", "1") - - async def get_room_predecessor(self, room_id: str) -> Optional[dict]: - """Get the predecessor of an upgraded room if it exists. - Otherwise return None. - - Args: - room_id: The room ID. - - Returns: - A dictionary containing the structure of the predecessor - field from the room's create event. The structure is subject to other servers, - but it is expected to be: - * room_id (str): The room ID of the predecessor room - * event_id (str): The ID of the tombstone event in the predecessor room - - None if a predecessor key is not found, or is not a dictionary. - - Raises: - NotFoundError if the given room is unknown - """ - # Retrieve the room's create event - create_event = await self.get_create_event_for_room(room_id) - - # Retrieve the predecessor key of the create event - predecessor = create_event.content.get("predecessor", None) - - # Ensure the key is a dictionary - if not isinstance(predecessor, collections.abc.Mapping): - return None - - return predecessor - - async def get_create_event_for_room(self, room_id: str) -> EventBase: - """Get the create state event for a room. - - Args: - room_id: The room ID. - - Returns: - The room creation event. - - Raises: - NotFoundError if the room is unknown - """ - state_ids = await self.get_current_state_ids(room_id) - create_id = state_ids.get((EventTypes.Create, "")) - - # If we can't find the create event, assume we've hit a dead end - if not create_id: - raise NotFoundError("Unknown room %s" % (room_id,)) - - # Retrieve the room's create event and return - create_event = await self.get_event(create_id) - return create_event - - @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): - """Get the current state event ids for a room based on the - current_state_events table. - - Args: - room_id (str) - - Returns: - deferred: dict of (type, state_key) -> event_id - """ - - def _get_current_state_ids_txn(txn): - txn.execute( - """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """, - (room_id,), - ) - - return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - - return self.db.runInteraction( - "get_current_state_ids", _get_current_state_ids_txn - ) - - # FIXME: how should this be cached? - def get_filtered_current_state_ids( - self, room_id: str, state_filter: StateFilter = StateFilter.all() - ): - """Get the current state event of a given type for a room based on the - current_state_events table. This may not be as up-to-date as the result - of doing a fresh state resolution as per state_handler.get_current_state - - Args: - room_id - state_filter: The state filter used to fetch state - from the database. - - Returns: - defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. - """ - - where_clause, where_args = state_filter.make_sql_filter_clause() - - if not where_clause: - # We delegate to the cached version - return self.get_current_state_ids(room_id) - - def _get_filtered_current_state_ids_txn(txn): - results = {} - sql = """ - SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? 
- """ - - if where_clause: - sql += " AND (%s)" % (where_clause,) - - args = [room_id] - args.extend(where_args) - txn.execute(sql, args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id - - return results - - return self.db.runInteraction( - "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn - ) - - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: - """Get canonical alias for room, if any - - Args: - room_id: The room ID - - Returns: - The canonical alias, if any - """ - - state = await self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return - - event = await self.get_event(event_id, allow_none=True) - if not event: - return - - return event.content.get("canonical_alias") - - @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db.simple_select_one_onecol( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - desc="_get_state_group_for_event", - ) - - @cachedList( - cached_method_name="_get_state_group_for_event", - list_name="event_ids", - num_args=1, - inlineCallbacks=True, - ) - def _get_state_group_for_events(self, event_ids): - """Returns mapping event_id -> state_group - """ - rows = yield self.db.simple_select_many_batch( - table="event_to_state_groups", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("event_id", "state_group"), - desc="_get_state_group_for_events", - ) - - return {row["event_id"]: row["state_group"] for row in rows} - - async def get_referenced_state_groups( - self, state_groups: Iterable[int] - ) -> Set[int]: - """Check if the state groups are referenced by events. - - Args: - state_groups - - Returns: - The subset of state groups that are referenced. - """ - - rows = await self.db.simple_select_many_batch( - table="event_to_state_groups", - column="state_group", - iterable=state_groups, - keyvalues={}, - retcols=("DISTINCT state_group",), - desc="get_referenced_state_groups", - ) - - return {row["state_group"] for row in rows} - - -class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): - - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - - def __init__(self, database: Database, db_conn, hs): - super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.server_name = hs.hostname - - self.db.updates.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", - ) - self.db.updates.register_background_index_update( - self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, - index_name="event_to_state_groups_sg_index", - table="event_to_state_groups", - columns=["state_group"], - ) - self.db.updates.register_background_update_handler( - self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, - ) - - async def _background_remove_left_rooms(self, progress, batch_size): - """Background update to delete rows from `current_state_events` and - `event_forward_extremities` tables of rooms that the server is no - longer joined to. 
- """ - - last_room_id = progress.get("last_room_id", "") - - def _background_remove_left_rooms_txn(txn): - # get a batch of room ids to consider - sql = """ - SELECT DISTINCT room_id FROM current_state_events - WHERE room_id > ? ORDER BY room_id LIMIT ? - """ - - txn.execute(sql, (last_room_id, batch_size)) - room_ids = [row[0] for row in txn] - if not room_ids: - return True, set() - - ########################################################################### - # - # exclude rooms where we have active members - - sql = """ - SELECT room_id - FROM local_current_membership - WHERE - room_id > ? AND room_id <= ? - AND membership = 'join' - GROUP BY room_id - """ - - txn.execute(sql, (last_room_id, room_ids[-1])) - joined_room_ids = {row[0] for row in txn} - to_delete = set(room_ids) - joined_room_ids - - ########################################################################### - # - # exclude rooms which we are in the process of constructing; these otherwise - # qualify as "rooms with no local users", and would have their - # forward extremities cleaned up. - - # the following query will return a list of rooms which have forward - # extremities that are *not* also the create event in the room - ie - # those that are not being created currently. - - sql = """ - SELECT DISTINCT efe.room_id - FROM event_forward_extremities efe - LEFT JOIN current_state_events cse ON - cse.event_id = efe.event_id - AND cse.type = 'm.room.create' - AND cse.state_key = '' - WHERE - cse.event_id IS NULL - AND efe.room_id > ? AND efe.room_id <= ? - """ - - txn.execute(sql, (last_room_id, room_ids[-1])) - - # build a set of those rooms within `to_delete` that do not appear in - # the above, leaving us with the rooms in `to_delete` that *are* being - # created. - creating_rooms = to_delete.difference(row[0] for row in txn) - logger.info("skipping rooms which are being created: %s", creating_rooms) - - # now remove the rooms being created from the list of those to delete. - # - # (we could have just taken the intersection of `to_delete` with the result - # of the sql query, but it's useful to be able to log `creating_rooms`; and - # having done so, it's quicker to remove the (few) creating rooms from - # `to_delete` than it is to form the intersection with the (larger) list of - # not-creating-rooms) - - to_delete -= creating_rooms - - ########################################################################### - # - # now clear the state for the rooms - - logger.info("Deleting current state left rooms: %r", to_delete) - - # First we get all users that we still think were joined to the - # room. This is so that we can mark those device lists as - # potentially stale, since there may have been a period where the - # server didn't share a room with the remote user and therefore may - # have missed any device updates. - rows = self.db.simple_select_many_txn( - txn, - table="current_state_events", - column="room_id", - iterable=to_delete, - keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, - retcols=("state_key",), - ) - - potentially_left_users = {row["state_key"] for row in rows} - - # Now lets actually delete the rooms from the DB. 
- self.db.simple_delete_many_txn( - txn, - table="current_state_events", - column="room_id", - iterable=to_delete, - keyvalues={}, - ) - - self.db.simple_delete_many_txn( - txn, - table="event_forward_extremities", - column="room_id", - iterable=to_delete, - keyvalues={}, - ) - - self.db.updates._background_update_progress_txn( - txn, - self.DELETE_CURRENT_STATE_UPDATE_NAME, - {"last_room_id": room_ids[-1]}, - ) - - return False, potentially_left_users - - finished, potentially_left_users = await self.db.runInteraction( - "_background_remove_left_rooms", _background_remove_left_rooms_txn - ) - - if finished: - await self.db.updates._end_background_update( - self.DELETE_CURRENT_STATE_UPDATE_NAME - ) - - # Now go and check if we still share a room with the remote users in - # the deleted rooms. If not mark their device lists as stale. - joined_users = await self.get_users_server_still_shares_room_with( - potentially_left_users - ) - - for user_id in potentially_left_users - joined_users: - await self.mark_remote_user_device_list_as_unsubscribed(user_id) - - return batch_size - - -class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. - """ - - def __init__(self, database: Database, db_conn, hs): - super(StateStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py deleted file mode 100644 index 725e12507f..0000000000 --- a/synapse/storage/data_stores/main/state_deltas.py +++ /dev/null @@ -1,121 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): - """Fetch a list of room state changes since the given stream id - - Each entry in the result contains the following fields: - - stream_id (int) - - room_id (str) - - type (str): event type - - state_key (str): - - event_id (str|None): new event_id for this state key. None if the - state has been deleted. 
- - prev_event_id (str|None): previous event_id for this state key. None - if it's new state. - - Args: - prev_stream_id (int): point to get changes since (exclusive) - max_stream_id (int): the point that we know has been correctly persisted - - ie, an upper limit to return changes from. - - Returns: - Deferred[tuple[int, list[dict]]: A tuple consisting of: - - the stream id which these results go up to - - list of current_state_delta_stream rows. If it is empty, we are - up to date. - """ - prev_stream_id = int(prev_stream_id) - - # check we're not going backwards - assert prev_stream_id <= max_stream_id - - if not self._curr_state_delta_stream_cache.has_any_entity_changed( - prev_stream_id - ): - # if the CSDs haven't changed between prev_stream_id and now, we - # know for certain that they haven't changed between prev_stream_id and - # max_stream_id. - return defer.succeed((max_stream_id, [])) - - def get_current_state_deltas_txn(txn): - # First we calculate the max stream id that will give us less than - # N results. - # We arbitarily limit to 100 stream_id entries to ensure we don't - # select toooo many. - sql = """ - SELECT stream_id, count(*) - FROM current_state_delta_stream - WHERE stream_id > ? AND stream_id <= ? - GROUP BY stream_id - ORDER BY stream_id ASC - LIMIT 100 - """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - - total = 0 - - for stream_id, count in txn: - total += count - if total > 100: - # We arbitarily limit to 100 entries to ensure we don't - # select toooo many. - logger.debug( - "Clipping current_state_delta_stream rows to stream_id %i", - stream_id, - ) - clipped_stream_id = stream_id - break - else: - # if there's no problem, we may as well go right up to the max_stream_id - clipped_stream_id = max_stream_id - - # Now actually get the deltas - sql = """ - SELECT stream_id, room_id, type, state_key, event_id, prev_event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - """ - txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.db.cursor_to_dict(txn) - - return self.db.runInteraction( - "get_current_state_deltas", get_current_state_deltas_txn - ) - - def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self.db.simple_select_one_onecol_txn( - txn, - table="current_state_delta_stream", - keyvalues={}, - retcol="COALESCE(MAX(stream_id), -1)", - ) - - def get_max_stream_id_in_current_state_deltas(self): - return self.db.runInteraction( - "get_max_stream_id_in_current_state_deltas", - self._get_max_stream_id_in_current_state_deltas_txn, - ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py deleted file mode 100644 index 40db8f594e..0000000000 --- a/synapse/storage/data_stores/main/stats.py +++ /dev/null @@ -1,878 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018, 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
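get_current_state_deltas above keeps its result bounded by first scanning per-stream_id row counts and clipping the upper bound at the first stream ID that pushes the running total past 100, so a stream ID is either returned in full or not at all. A plain-Python sketch of that clipping loop with illustrative counts:

    def clip_stream_id(counts, max_stream_id, limit=100):
        """counts: (stream_id, row_count) pairs, ascending by stream_id."""
        total = 0
        for stream_id, count in counts:
            total += count
            if total > limit:
                # stop here; rows up to and including this stream_id get returned
                return stream_id
        # never went over the limit, so the caller's bound is safe to use
        return max_stream_id

    counts = [(5, 40), (6, 50), (7, 30), (8, 10)]
    print(clip_stream_id(counts, max_stream_id=20))  # 7, since 40 + 50 + 30 > 100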
- -import logging -from itertools import chain -from typing import Tuple - -from twisted.internet.defer import DeferredLock - -from synapse.api.constants import EventTypes, Membership -from synapse.storage.data_stores.main.state_deltas import StateDeltasStore -from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - -# these fields track absolutes (e.g. total number of rooms on the server) -# You can think of these as Prometheus Gauges. -# You can draw these stats on a line graph. -# Example: number of users in a room -ABSOLUTE_STATS_FIELDS = { - "room": ( - "current_state_events", - "joined_members", - "invited_members", - "left_members", - "banned_members", - "local_users_in_room", - ), - "user": ("joined_rooms",), -} - -# these fields are per-timeslice and so should be reset to 0 upon a new slice -# You can draw these stats on a histogram. -# Example: number of events sent locally during a time slice -PER_SLICE_FIELDS = { - "room": ("total_events", "total_event_bytes"), - "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), -} - -TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} - -# these are the tables (& ID columns) which contain our actual subjects -TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} - - -class StatsStore(StateDeltasStore): - def __init__(self, database: Database, db_conn, hs): - super(StatsStore, self).__init__(database, db_conn, hs) - - self.server_name = hs.hostname - self.clock = self.hs.get_clock() - self.stats_enabled = hs.config.stats_enabled - self.stats_bucket_size = hs.config.stats_bucket_size - - self.stats_delta_processing_lock = DeferredLock() - - self.db.updates.register_background_update_handler( - "populate_stats_process_rooms", self._populate_stats_process_rooms - ) - self.db.updates.register_background_update_handler( - "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 - ) - self.db.updates.register_background_update_handler( - "populate_stats_process_users", self._populate_stats_process_users - ) - # we no longer need to perform clean-up, but we will give ourselves - # the potential to reintroduce it in the future – so documentation - # will still encourage the use of this no-op handler. - self.db.updates.register_noop_background_update("populate_stats_cleanup") - self.db.updates.register_noop_background_update("populate_stats_prepare") - - def quantise_stats_time(self, ts): - """ - Quantises a timestamp to be a multiple of the bucket size. - - Args: - ts (int): the timestamp to quantise, in milliseconds since the Unix - Epoch - - Returns: - int: a timestamp which - - is divisible by the bucket size; - - is no later than `ts`; and - - is the largest such timestamp. - """ - return (ts // self.stats_bucket_size) * self.stats_bucket_size - - async def _populate_stats_process_users(self, progress, batch_size): - """ - This is a background update which regenerates statistics for users. - """ - if not self.stats_enabled: - await self.db.updates._end_background_update("populate_stats_process_users") - return 1 - - last_user_id = progress.get("last_user_id", "") - - def _get_next_batch(txn): - sql = """ - SELECT DISTINCT name FROM users - WHERE name > ? - ORDER BY name ASC - LIMIT ? 
- """ - txn.execute(sql, (last_user_id, batch_size)) - return [r for r, in txn] - - users_to_work_on = await self.db.runInteraction( - "_populate_stats_process_users", _get_next_batch - ) - - # No more rooms -- complete the transaction. - if not users_to_work_on: - await self.db.updates._end_background_update("populate_stats_process_users") - return 1 - - for user_id in users_to_work_on: - await self._calculate_and_set_initial_state_for_user(user_id) - progress["last_user_id"] = user_id - - await self.db.runInteraction( - "populate_stats_process_users", - self.db.updates._background_update_progress_txn, - "populate_stats_process_users", - progress, - ) - - return len(users_to_work_on) - - async def _populate_stats_process_rooms(self, progress, batch_size): - """ - This was a background update which regenerated statistics for rooms. - - It has been replaced by StatsStore._populate_stats_process_rooms_2. This background - job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure - someone upgrading from ? - ORDER BY room_id ASC - LIMIT ? - """ - txn.execute(sql, (last_room_id, batch_size)) - return [r for r, in txn] - - rooms_to_work_on = await self.db.runInteraction( - "populate_stats_rooms_2_get_batch", _get_next_batch - ) - - # No more rooms -- complete the transaction. - if not rooms_to_work_on: - await self.db.updates._end_background_update( - "populate_stats_process_rooms_2" - ) - return 1 - - for room_id in rooms_to_work_on: - await self._calculate_and_set_initial_state_for_room(room_id) - progress["last_room_id"] = room_id - - await self.db.runInteraction( - "_populate_stats_process_rooms_2", - self.db.updates._background_update_progress_txn, - "populate_stats_process_rooms_2", - progress, - ) - - return len(rooms_to_work_on) - - def get_stats_positions(self): - """ - Returns the stats processor positions. - """ - return self.db.simple_select_one_onecol( - table="stats_incremental_position", - keyvalues={}, - retcol="stream_id", - desc="stats_incremental_position", - ) - - def update_room_state(self, room_id, fields): - """ - Args: - room_id (str) - fields (dict[str:Any]) - """ - - # For whatever reason some of the fields may contain null bytes, which - # postgres isn't a fan of, so we replace those fields with null. - for col in ( - "join_rules", - "history_visibility", - "encryption", - "name", - "topic", - "avatar", - "canonical_alias", - ): - field = fields.get(col) - if field and "\0" in field: - fields[col] = None - - return self.db.simple_upsert( - table="room_stats_state", - keyvalues={"room_id": room_id}, - values=fields, - desc="update_room_state", - ) - - def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): - """ - Get statistics for a given subject. - - Args: - stats_type (str): The type of subject - stats_id (str): The ID of the subject (e.g. room_id or user_id) - start (int): Pagination start. Number of entries, not timestamp. - size (int): How many entries to return. - - Returns: - Deferred[list[dict]], where the dict has the keys of - ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". - """ - return self.db.runInteraction( - "get_statistics_for_subject", - self._get_statistics_for_subject_txn, - stats_type, - stats_id, - start, - size, - ) - - def _get_statistics_for_subject_txn( - self, txn, stats_type, stats_id, start, size=100 - ): - """ - Transaction-bound version of L{get_statistics_for_subject}. 
- """ - - table, id_col = TYPE_TO_TABLE[stats_type] - selected_columns = list( - ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] - ) - - slice_list = self.db.simple_select_list_paginate_txn( - txn, - table + "_historical", - "end_ts", - start, - size, - retcols=selected_columns + ["bucket_size", "end_ts"], - keyvalues={id_col: stats_id}, - order_direction="DESC", - ) - - return slice_list - - @cached() - def get_earliest_token_for_stats(self, stats_type, id): - """ - Fetch the "earliest token". This is used by the room stats delta - processor to ignore deltas that have been processed between the - start of the background task and any particular room's stats - being calculated. - - Returns: - Deferred[int] - """ - table, id_col = TYPE_TO_TABLE[stats_type] - - return self.db.simple_select_one_onecol( - "%s_current" % (table,), - keyvalues={id_col: id}, - retcol="completed_delta_stream_id", - allow_none=True, - ) - - def bulk_update_stats_delta(self, ts, updates, stream_id): - """Bulk update stats tables for a given stream_id and updates the stats - incremental position. - - Args: - ts (int): Current timestamp in ms - updates(dict[str, dict[str, dict[str, Counter]]]): The updates to - commit as a mapping stats_type -> stats_id -> field -> delta. - stream_id (int): Current position. - - Returns: - Deferred - """ - - def _bulk_update_stats_delta_txn(txn): - for stats_type, stats_updates in updates.items(): - for stats_id, fields in stats_updates.items(): - logger.debug( - "Updating %s stats for %s: %s", stats_type, stats_id, fields - ) - self._update_stats_delta_txn( - txn, - ts=ts, - stats_type=stats_type, - stats_id=stats_id, - fields=fields, - complete_with_stream_id=stream_id, - ) - - self.db.simple_update_one_txn( - txn, - table="stats_incremental_position", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - ) - - return self.db.runInteraction( - "bulk_update_stats_delta", _bulk_update_stats_delta_txn - ) - - def update_stats_delta( - self, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): - """ - Updates the statistics for a subject, with a delta (difference/relative - change). - - Args: - ts (int): timestamp of the change - stats_type (str): "room" or "user" – the kind of subject - stats_id (str): the subject's ID (room ID or user ID) - fields (dict[str, int]): Deltas of stats values. - complete_with_stream_id (int, optional): - If supplied, converts an incomplete row into a complete row, - with the supplied stream_id marked as the stream_id where the - row was completed. - absolute_field_overrides (dict[str, int]): Current stats values - (i.e. not deltas) of absolute fields. - Does not work with per-slice fields. 
- """ - - return self.db.runInteraction( - "update_stats_delta", - self._update_stats_delta_txn, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id=complete_with_stream_id, - absolute_field_overrides=absolute_field_overrides, - ) - - def _update_stats_delta_txn( - self, - txn, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): - if absolute_field_overrides is None: - absolute_field_overrides = {} - - table, id_col = TYPE_TO_TABLE[stats_type] - - quantised_ts = self.quantise_stats_time(int(ts)) - end_ts = quantised_ts + self.stats_bucket_size - - # Lets be paranoid and check that all the given field names are known - abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] - slice_field_names = PER_SLICE_FIELDS[stats_type] - for field in chain(fields.keys(), absolute_field_overrides.keys()): - if field not in abs_field_names and field not in slice_field_names: - # guard against potential SQL injection dodginess - raise ValueError( - "%s is not a recognised field" - " for stats type %s" % (field, stats_type) - ) - - # Per slice fields do not get added to the _current table - - # This calculates the deltas (`field = field + ?` values) - # for absolute fields, - # * defaulting to 0 if not specified - # (required for the INSERT part of upserting to work) - # * omitting overrides specified in `absolute_field_overrides` - deltas_of_absolute_fields = { - key: fields.get(key, 0) - for key in abs_field_names - if key not in absolute_field_overrides - } - - # Keep the delta stream ID field up to date - absolute_field_overrides = absolute_field_overrides.copy() - absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id - - # first upsert the `_current` table - self._upsert_with_additive_relatives_txn( - txn=txn, - table=table + "_current", - keyvalues={id_col: stats_id}, - absolutes=absolute_field_overrides, - additive_relatives=deltas_of_absolute_fields, - ) - - per_slice_additive_relatives = { - key: fields.get(key, 0) for key in slice_field_names - } - self._upsert_copy_from_table_with_additive_relatives_txn( - txn=txn, - into_table=table + "_historical", - keyvalues={id_col: stats_id}, - extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, - extra_dst_keyvalues={"end_ts": end_ts}, - additive_relatives=per_slice_additive_relatives, - src_table=table + "_current", - copy_columns=abs_field_names, - ) - - def _upsert_with_additive_relatives_txn( - self, txn, table, keyvalues, absolutes, additive_relatives - ): - """Used to update values in the stats tables. - - This is basically a slightly convoluted upsert that *adds* to any - existing rows. - - Args: - txn - table (str): Table name - keyvalues (dict[str, any]): Row-identifying key values - absolutes (dict[str, any]): Absolute (set) fields - additive_relatives (dict[str, int]): Fields that will be added onto - if existing row present. 
- """ - if self.database_engine.can_native_upsert: - absolute_updates = [ - "%(field)s = EXCLUDED.%(field)s" % {"field": field} - for field in absolutes.keys() - ] - - relative_updates = [ - "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" - % {"table": table, "field": field} - for field in additive_relatives.keys() - ] - - insert_cols = [] - qargs = [] - - for (key, val) in chain( - keyvalues.items(), absolutes.items(), additive_relatives.items() - ): - insert_cols.append(key) - qargs.append(val) - - sql = """ - INSERT INTO %(table)s (%(insert_cols_cs)s) - VALUES (%(insert_vals_qs)s) - ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s - """ % { - "table": table, - "insert_cols_cs": ", ".join(insert_cols), - "insert_vals_qs": ", ".join( - ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) - ), - "key_columns": ", ".join(keyvalues), - "updates": ", ".join(chain(absolute_updates, relative_updates)), - } - - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, table) - retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self.db.simple_select_one_txn( - txn, table, keyvalues, retcols, allow_none=True - ) - if current_row is None: - merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self.db.simple_insert_txn(txn, table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - current_row[key] += val - current_row.update(absolutes) - self.db.simple_update_one_txn(txn, table, keyvalues, current_row) - - def _upsert_copy_from_table_with_additive_relatives_txn( - self, - txn, - into_table, - keyvalues, - extra_dst_keyvalues, - extra_dst_insvalues, - additive_relatives, - src_table, - copy_columns, - ): - """Updates the historic stats table with latest updates. - - This involves copying "absolute" fields from the `_current` table, and - adding relative fields to any existing values. - - Args: - txn: Transaction - into_table (str): The destination table to UPSERT the row into - keyvalues (dict[str, any]): Row-identifying key values - extra_dst_keyvalues (dict[str, any]): Additional keyvalues - for `into_table`. - extra_dst_insvalues (dict[str, any]): Additional values to insert - on new row creation for `into_table`. - additive_relatives (dict[str, any]): Fields that will be added onto - if existing row present. (Must be disjoint from copy_columns.) - src_table (str): The source table to copy from - copy_columns (iterable[str]): The list of columns to copy - """ - if self.database_engine.can_native_upsert: - ins_columns = chain( - keyvalues, - copy_columns, - additive_relatives, - extra_dst_keyvalues, - extra_dst_insvalues, - ) - sel_exprs = chain( - keyvalues, - copy_columns, - ( - "?" - for _ in chain( - additive_relatives, extra_dst_keyvalues, extra_dst_insvalues - ) - ), - ) - keyvalues_where = ("%s = ?" 
% f for f in keyvalues) - - sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) - sets_ar = ( - "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) - for f in additive_relatives - ) - - sql = """ - INSERT INTO %(into_table)s (%(ins_columns)s) - SELECT %(sel_exprs)s - FROM %(src_table)s - WHERE %(keyvalues_where)s - ON CONFLICT (%(keyvalues)s) - DO UPDATE SET %(sets)s - """ % { - "into_table": into_table, - "ins_columns": ", ".join(ins_columns), - "sel_exprs": ", ".join(sel_exprs), - "keyvalues_where": " AND ".join(keyvalues_where), - "src_table": src_table, - "keyvalues": ", ".join( - chain(keyvalues.keys(), extra_dst_keyvalues.keys()) - ), - "sets": ", ".join(chain(sets_cc, sets_ar)), - } - - qargs = list( - chain( - additive_relatives.values(), - extra_dst_keyvalues.values(), - extra_dst_insvalues.values(), - keyvalues.values(), - ) - ) - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, into_table) - src_row = self.db.simple_select_one_txn( - txn, src_table, keyvalues, copy_columns - ) - all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.db.simple_select_one_txn( - txn, - into_table, - keyvalues=all_dest_keyvalues, - retcols=list(chain(additive_relatives.keys(), copy_columns)), - allow_none=True, - ) - - if dest_current_row is None: - merged_dict = { - **keyvalues, - **extra_dst_keyvalues, - **extra_dst_insvalues, - **src_row, - **additive_relatives, - } - self.db.simple_insert_txn(txn, into_table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - src_row[key] = dest_current_row[key] + val - self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) - - def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): - """Fetches the counts of events in the given range of stream IDs. - - Args: - min_pos (int) - max_pos (int) - - Returns: - Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field - changes. - """ - - return self.db.runInteraction( - "stats_incremental_total_events_and_bytes", - self.get_changes_room_total_events_and_bytes_txn, - min_pos, - max_pos, - ) - - def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): - """Gets the total_events and total_event_bytes counts for rooms and - senders, in a range of stream_orderings (including backfilled events). - - Args: - txn - low_pos (int): Low stream ordering - high_pos (int): High stream ordering - - Returns: - tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The - room and user deltas for total_events/total_event_bytes in the - format of `stats_id` -> fields - """ - - if low_pos >= high_pos: - # nothing to do here. - return {}, {} - - if isinstance(self.database_engine, PostgresEngine): - new_bytes_expression = "OCTET_LENGTH(json)" - else: - new_bytes_expression = "LENGTH(CAST(json AS BLOB))" - - sql = """ - SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.room_id - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - room_deltas = { - room_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for room_id, new_events, new_bytes in txn - } - - sql = """ - SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) 
- OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.sender - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - user_deltas = { - user_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for user_id, new_events, new_bytes in txn - if self.hs.is_mine_id(user_id) - } - - return room_deltas, user_deltas - - async def _calculate_and_set_initial_state_for_room( - self, room_id: str - ) -> Tuple[dict, dict, int]: - """Calculate and insert an entry into room_stats_current. - - Args: - room_id: The room ID under calculation. - - Returns: - A tuple of room state, membership counts and stream position. - """ - - def _fetch_current_state_stats(txn): - pos = self.get_room_max_stream_ordering() - - rows = self.db.simple_select_many_txn( - txn, - table="current_state_events", - column="type", - iterable=[ - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.RoomHistoryVisibility, - EventTypes.RoomEncryption, - EventTypes.Name, - EventTypes.Topic, - EventTypes.RoomAvatar, - EventTypes.CanonicalAlias, - ], - keyvalues={"room_id": room_id, "state_key": ""}, - retcols=["event_id"], - ) - - event_ids = [row["event_id"] for row in rows] - - txn.execute( - """ - SELECT membership, count(*) FROM current_state_events - WHERE room_id = ? AND type = 'm.room.member' - GROUP BY membership - """, - (room_id,), - ) - membership_counts = {membership: cnt for membership, cnt in txn} - - txn.execute( - """ - SELECT COALESCE(count(*), 0) FROM current_state_events - WHERE room_id = ? - """, - (room_id,), - ) - - (current_state_events_count,) = txn.fetchone() - - users_in_room = self.get_users_in_room_txn(txn, room_id) - - return ( - event_ids, - membership_counts, - current_state_events_count, - users_in_room, - pos, - ) - - ( - event_ids, - membership_counts, - current_state_events_count, - users_in_room, - pos, - ) = await self.db.runInteraction( - "get_initial_state_for_room", _fetch_current_state_stats - ) - - state_event_map = await self.get_events(event_ids, get_prev_content=False) - - room_state = { - "join_rules": None, - "history_visibility": None, - "encryption": None, - "name": None, - "topic": None, - "avatar": None, - "canonical_alias": None, - "is_federatable": True, - } - - for event in state_event_map.values(): - if event.type == EventTypes.JoinRules: - room_state["join_rules"] = event.content.get("join_rule") - elif event.type == EventTypes.RoomHistoryVisibility: - room_state["history_visibility"] = event.content.get( - "history_visibility" - ) - elif event.type == EventTypes.RoomEncryption: - room_state["encryption"] = event.content.get("algorithm") - elif event.type == EventTypes.Name: - room_state["name"] = event.content.get("name") - elif event.type == EventTypes.Topic: - room_state["topic"] = event.content.get("topic") - elif event.type == EventTypes.RoomAvatar: - room_state["avatar"] = event.content.get("url") - elif event.type == EventTypes.CanonicalAlias: - room_state["canonical_alias"] = event.content.get("alias") - elif event.type == EventTypes.Create: - room_state["is_federatable"] = ( - event.content.get("m.federate", True) is True - ) - - await self.update_room_state(room_id, room_state) - - local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - - await self.update_stats_delta( - ts=self.clock.time_msec(), - stats_type="room", - stats_id=room_id, - fields={}, - complete_with_stream_id=pos, - absolute_field_overrides={ - "current_state_events": current_state_events_count, - 
"joined_members": membership_counts.get(Membership.JOIN, 0), - "invited_members": membership_counts.get(Membership.INVITE, 0), - "left_members": membership_counts.get(Membership.LEAVE, 0), - "banned_members": membership_counts.get(Membership.BAN, 0), - "local_users_in_room": len(local_users_in_room), - }, - ) - - async def _calculate_and_set_initial_state_for_user(self, user_id): - def _calculate_and_set_initial_state_for_user_txn(txn): - pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) - - txn.execute( - """ - SELECT COUNT(distinct room_id) FROM current_state_events - WHERE type = 'm.room.member' AND state_key = ? - AND membership = 'join' - """, - (user_id,), - ) - (count,) = txn.fetchone() - return count, pos - - joined_rooms, pos = await self.db.runInteraction( - "calculate_and_set_initial_state_for_user", - _calculate_and_set_initial_state_for_user_txn, - ) - - await self.update_stats_delta( - ts=self.clock.time_msec(), - stats_type="user", - stats_id=user_id, - fields={}, - complete_with_stream_id=pos, - absolute_field_overrides={"joined_rooms": joined_rooms}, - ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py deleted file mode 100644 index f1334a6efc..0000000000 --- a/synapse/storage/data_stores/main/stream.py +++ /dev/null @@ -1,1064 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" This module is responsible for getting events from the DB for pagination -and event streaming. - -The order it returns events in depend on whether we are streaming forwards or -are paginating backwards. We do this because we want to handle out of order -messages nicely, while still returning them in the correct order when we -paginate bacwards. - -This is implemented by keeping two ordering columns: stream_ordering and -topological_ordering. Stream ordering is basically insertion/received order -(except for events from backfill requests). The topological_ordering is a -weak ordering of events based on the pdu graph. - -This means that we have to have two different types of tokens, depending on -what sort order was used: - - stream tokens are of the form: "s%d", which maps directly to the column - - topological tokems: "t%d-%d", where the integers map to the topological - and stream ordering columns respectively. 
-""" - -import abc -import logging -from collections import namedtuple -from typing import Optional - -from twisted.internet import defer - -from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database, make_in_list_sql_clause -from synapse.storage.engines import PostgresEngine -from synapse.types import RoomStreamToken -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -MAX_STREAM_SIZE = 1000 - - -_STREAM_TOKEN = "stream" -_TOPOLOGICAL_TOKEN = "topological" - - -# Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) - - -def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): - """Creates an SQL expression to bound the columns by the pagination - tokens. - - For example creates an SQL expression like: - - (6, 7) >= (topological_ordering, stream_ordering) - AND (5, 3) < (topological_ordering, stream_ordering) - - would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). - - Note that tokens are considered to be after the row they are in, e.g. if - a row A has a token T, then we consider A to be before T. This convention - is important when figuring out inequalities for the generated SQL, and - produces the following result: - - If paginating forwards then we exclude any rows matching the from - token, but include those that match the to token. - - If paginating backwards then we include any rows matching the from - token, but include those that match the to token. - - Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". - engine: The database engine to generate the clauses for - - Returns: - str: The sql expression - """ - assert direction in ("b", "f") - - where_clause = [] - if from_token: - where_clause.append( - _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", - column_names=column_names, - values=from_token, - engine=engine, - ) - ) - - if to_token: - where_clause.append( - _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", - column_names=column_names, - values=to_token, - engine=engine, - ) - ) - - return " AND ".join(where_clause) - - -def _make_generic_sql_bound(bound, column_names, values, engine): - """Create an SQL expression that bounds the given column names by the - values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. - - Only works with two columns. - - Older versions of SQLite don't support that syntax so we have to expand it - out manually. - - Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", - "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. 
Must *not* be user defined - as these get inserted directly into the SQL statement without - escapes. - values (tuple[int|None, int]): The values to bound the columns by. If - the first value is None then only creates a bound on the second - column. - engine: The database engine to generate the SQL for - - Returns: - str - """ - - assert bound in (">", "<", ">=", "<=") - - name1, name2 = column_names - val1, val2 = values - - if val1 is None: - val2 = int(val2) - return "(%d %s %s)" % (val2, bound, name2) - - val1 = int(val1) - val2 = int(val2) - - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] - return rows - - rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) - - ret = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(ret, rows, topo_order=from_id is None) - - if order.lower() == "desc": - ret.reverse() - - if rows: - key = "s%d" % min(r.stream_ordering for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = from_key - - return ret, key - - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): - from_id = RoomStreamToken.parse_stream_token(from_key).stream - to_id = RoomStreamToken.parse_stream_token(to_key).stream - - if from_key == to_key: - return [] - - if from_id: - has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) - ) - if not has_changed: - return [] - - def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] - - return rows - - rows = yield self.db.runInteraction("get_membership_changes_for_user", f) - - ret = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(ret, rows, topo_order=False) - - return ret - - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): - """Get the most recent events in the room in topological ordering. - - Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. - - Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. - """ - - rows, token = yield self.get_recent_event_ids_for_room( - room_id, limit, end_token - ) - - events = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(events, rows) - - return (events, token) - - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): - """Get the most recent events in the room in topological ordering. - - Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. - - Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. 
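Returning to the row-value bounds built by generate_pagination_where_clause above: for engines without row-value comparison support, the bound has to be expanded into plain boolean SQL, as _make_generic_sql_bound notes. A standalone sketch of that expansion (hypothetical helper, column names assumed trusted exactly as the docstring requires):

    def expand_row_value_bound(bound: str, column_names, values) -> str:
        # (v1, v2) <bound> (c1, c2) is equivalent to:
        #   (v1 <strict bound> c1) OR (v1 = c1 AND v2 <bound> c2)
        name1, name2 = column_names
        val1, val2 = int(values[0]), int(values[1])
        strict = bound.rstrip("=")  # ">=" -> ">", "<=" -> "<"
        return "(%d %s %s) OR (%d = %s AND %d %s %s)" % (
            val1, strict, name1,
            val1, name1,
            val2, bound, name2,
        )

    print(expand_row_value_bound(">=", ("topological_ordering", "stream_ordering"), (6, 7)))
    # (6 > topological_ordering) OR (6 = topological_ordering AND 7 >= stream_ordering)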
- The events returned are in ascending order. - """ - # Allow a zero limit here, and no-op. - if limit == 0: - return [], end_token - - end_token = RoomStreamToken.parse(end_token) - - rows, token = yield self.db.runInteraction( - "get_recent_event_ids_for_room", - self._paginate_room_events_txn, - room_id, - from_token=end_token, - limit=limit, - ) - - # We want to return the results in ascending order. - rows.reverse() - - return rows, token - - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): - """Gets details of the first event in a room at or before a stream ordering - - Args: - room_id (str): - stream_ordering (int): - - Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) - """ - - def _f(txn): - sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering <= ?" - " AND NOT outlier" - " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) - txn.execute(sql, (room_id, stream_ordering)) - return txn.fetchone() - - return self.db.runInteraction("get_room_event_before_stream_ordering", _f) - - async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: - """Returns the current token for rooms stream. - - By default, it returns the current global stream token. Specifying a - `room_id` causes it to return the current room specific topological - token. - """ - token = self.get_room_max_stream_ordering() - if room_id is None: - return "s%d" % (token,) - else: - topo = await self.db.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn, room_id - ) - return "t%d-%d" % (topo, token) - - def get_stream_token_for_event(self, event_id): - """The stream token for an event - Args: - event_id(str): The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A deferred "s%d" stream token. - """ - return self.db.simple_select_one_onecol( - table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) - - def get_topological_token_for_event(self, event_id): - """The stream token for an event - Args: - event_id(str): The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A deferred "t%d-%d" topological token. - """ - return self.db.simple_select_one( - table="events", - keyvalues={"event_id": event_id}, - retcols=("stream_ordering", "topological_ordering"), - desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - ) - - def get_max_topological_token(self, room_id, stream_key): - """Get the max topological token in a room before the given stream - ordering. - - Args: - room_id (str) - stream_key (int) - - Returns: - Deferred[int] - """ - sql = ( - "SELECT coalesce(max(topological_ordering), 0) FROM events" - " WHERE room_id = ? AND stream_ordering < ?" - ) - return self.db.execute( - "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) - - def _get_max_topological_txn(self, txn, room_id): - txn.execute( - "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", - (room_id,), - ) - - rows = txn.fetchall() - return rows[0][0] if rows else 0 - - @staticmethod - def _set_before_and_after(events, rows, topo_order=True): - """Inserts ordering information to events' internal metadata from - the DB rows. 
- - Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. - """ - for event, row in zip(events, rows): - stream = row.stream_ordering - if topo_order and row.topological_ordering: - topo = row.topological_ordering - else: - topo = None - internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) - internal.order = (int(topo) if topo else 0, int(stream)) - - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): - """Retrieve events and pagination tokens around a given event in a - room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict - """ - - results = yield self.db.runInteraction( - "get_events_around", - self._get_events_around_txn, - room_id, - event_id, - before_limit, - after_limit, - event_filter, - ) - - events_before = yield self.get_events_as_list( - list(results["before"]["event_ids"]), get_prev_content=True - ) - - events_after = yield self.get_events_as_list( - list(results["after"]["event_ids"]), get_prev_content=True - ) - - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } - - def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): - """Retrieves event_ids and pagination tokens around a given event in a - room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict - """ - - results = self.db.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], - ) - - # Paginating backwards includes the event at the token, but paginating - # forward doesn't. - before_token = RoomStreamToken( - results["topological_ordering"] - 1, results["stream_ordering"] - ) - - after_token = RoomStreamToken( - results["topological_ordering"], results["stream_ordering"] - ) - - rows, start_token = self._paginate_room_events_txn( - txn, - room_id, - before_token, - direction="b", - limit=before_limit, - event_filter=event_filter, - ) - events_before = [r.event_id for r in rows] - - rows, end_token = self._paginate_room_events_txn( - txn, - room_id, - after_token, - direction="f", - limit=after_limit, - event_filter=event_filter, - ) - events_after = [r.event_id for r in rows] - - return { - "before": {"event_ids": events_before, "token": start_token}, - "after": {"event_ids": events_after, "token": end_token}, - } - - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): - """Get all new events - - Returns all events with from_id < stream_ordering <= current_id. - - Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return - - Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. 
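To make the asymmetry in _get_events_around_txn above concrete (paginating backwards includes the event at the token, paginating forwards does not), here is an illustrative sketch of how the two anchor tokens are derived from an event's orderings; it mirrors the deleted code's (topological, stream) pair without depending on synapse.types:

    from collections import namedtuple

    Token = namedtuple("Token", ("topological", "stream"))

    def tokens_around(topological_ordering: int, stream_ordering: int):
        # Stepping the topological part back by one keeps the event itself out
        # of the backwards page; the forwards page naturally starts after it.
        before = Token(topological_ordering - 1, stream_ordering)
        after = Token(topological_ordering, stream_ordering)
        return before, after

    assert tokens_around(10, 99) == (Token(9, 99), Token(10, 99))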
- """ - - def get_all_new_events_stream_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id" - " FROM events AS e" - " WHERE" - " ? < e.stream_ordering AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - " LIMIT ?" - ) - - txn.execute(sql, (from_id, current_id, limit)) - rows = txn.fetchall() - - upper_bound = current_id - if len(rows) == limit: - upper_bound = rows[-1][0] - - return upper_bound, [row[1] for row in rows] - - upper_bound, event_ids = yield self.db.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = yield self.get_events_as_list(event_ids) - - return upper_bound, events - - async def get_federation_out_pos(self, typ: str) -> int: - if self._need_to_reset_federation_stream_positions: - await self.db.runInteraction( - "_reset_federation_positions_txn", self._reset_federation_positions_txn - ) - self._need_to_reset_federation_stream_positions = False - - return await self.db.simple_select_one_onecol( - table="federation_stream_position", - retcol="stream_id", - keyvalues={"type": typ, "instance_name": self._instance_name}, - desc="get_federation_out_pos", - ) - - async def update_federation_out_pos(self, typ, stream_id): - if self._need_to_reset_federation_stream_positions: - await self.db.runInteraction( - "_reset_federation_positions_txn", self._reset_federation_positions_txn - ) - self._need_to_reset_federation_stream_positions = False - - return await self.db.simple_update_one( - table="federation_stream_position", - keyvalues={"type": typ, "instance_name": self._instance_name}, - updatevalues={"stream_id": stream_id}, - desc="update_federation_out_pos", - ) - - def _reset_federation_positions_txn(self, txn): - """Fiddles with the `federation_stream_position` table to make it match - the configured federation sender instances during start up. - """ - - # The federation sender instances may have changed, so we need to - # massage the `federation_stream_position` table to have a row per type - # per instance sending federation. If there is a mismatch we update the - # table with the correct rows using the *minimum* stream ID seen. This - # may result in resending of events/EDUs to remote servers, but that is - # preferable to dropping them. - - if not self._send_federation: - return - - # Pull out the configured instances. If we don't have a shard config then - # we assume that we're the only instance sending. 
- configured_instances = self._federation_shard_config.instances - if not configured_instances: - configured_instances = [self._instance_name] - elif self._instance_name not in configured_instances: - return - - instances_in_table = self.db.simple_select_onecol_txn( - txn, - table="federation_stream_position", - keyvalues={}, - retcol="instance_name", - ) - - if set(instances_in_table) == set(configured_instances): - # Nothing to do - return - - sql = """ - SELECT type, MIN(stream_id) FROM federation_stream_position - GROUP BY type - """ - txn.execute(sql) - min_positions = dict(txn) # Map from type -> min position - - # Ensure we do actually have some values here - assert set(min_positions) == {"federation", "events"} - - sql = """ - DELETE FROM federation_stream_position - WHERE NOT (%s) - """ - clause, args = make_in_list_sql_clause( - txn.database_engine, "instance_name", configured_instances - ) - txn.execute(sql % (clause,), args) - - for typ, stream_id in min_positions.items(): - self.db.simple_upsert_txn( - txn, - table="federation_stream_position", - keyvalues={"type": typ, "instance_name": self._instance_name}, - values={"stream_id": stream_id}, - ) - - def has_room_changed_since(self, room_id, stream_id): - return self._events_stream_cache.has_entity_changed(room_id, stream_id) - - def _paginate_room_events_txn( - self, - txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): - """Returns list of events before or after a given token. - - Args: - txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. - - Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. - """ - - assert int(limit) >= 0 - - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. - args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - - bounds = generate_pagination_where_clause( - direction=direction, - column_names=("topological_ordering", "stream_ordering"), - from_token=from_token, - to_token=to_token, - engine=self.database_engine, - ) - - filter_clause, filter_args = filter_to_clause(event_filter) - - if filter_clause: - bounds += " AND " + filter_clause - args.extend(filter_args) - - args.append(int(limit)) - - select_keywords = "SELECT" - join_clause = "" - if event_filter and event_filter.labels: - # If we're not filtering on a label, then joining on event_labels will - # return as many row for a single event as the number of labels it has. To - # avoid this, only join if we're filtering on at least one label. 
- join_clause = """ - LEFT JOIN event_labels - USING (event_id, room_id, topological_ordering) - """ - if len(event_filter.labels) > 1: - # Using DISTINCT in this SELECT query is quite expensive, because it - # requires the engine to sort on the entire (not limited) result set, - # i.e. the entire events table. We only need to use it when we're - # filtering on more than two labels, because that's the only scenario - # in which we can possibly to get multiple times the same event ID in - # the results. - select_keywords += "DISTINCT" - - sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering - FROM events - %(join_clause)s - WHERE outlier = ? AND room_id = ? AND %(bounds)s - ORDER BY topological_ordering %(order)s, - stream_ordering %(order)s LIMIT ? - """ % { - "select_keywords": select_keywords, - "join_clause": join_clause, - "bounds": bounds, - "order": order, - } - - txn.execute(sql, args) - - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] - - if rows: - topo = rows[-1].topological_ordering - toke = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - toke -= 1 - next_token = RoomStreamToken(topo, toke) - else: - # TODO (erikj): We should work out what to do here instead. - next_token = to_token if to_token else from_token - - return rows, str(next_token) - - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): - """Returns list of events before or after a given token. - - Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. - - Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). 
- """ - - from_key = RoomStreamToken.parse(from_key) - if to_key: - to_key = RoomStreamToken.parse(to_key) - - rows, token = yield self.db.runInteraction( - "paginate_room_events", - self._paginate_room_events_txn, - room_id, - from_key, - to_key, - direction, - limit, - event_filter, - ) - - events = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(events, rows) - - return (events, token) - - -class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py deleted file mode 100644 index bd7227773a..0000000000 --- a/synapse/storage/data_stores/main/tags.py +++ /dev/null @@ -1,288 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import List, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import db_to_json -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - - -class TagsWorkerStore(AccountDataWorkerStore): - @cached() - def get_tags_for_user(self, user_id): - """Get all the tags for a user. - - - Args: - user_id(str): The user to get the tags for. - Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. - """ - - deferred = self.db.simple_select_list( - "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] - ) - - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) - return tags_by_room - - return deferred - - async def get_all_updated_tags( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for tags replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. 
- - The updates are a list of 2-tuples of stream ID and the row data - """ - - if last_id == current_id: - return [], current_id, False - - def get_all_updated_tags_txn(txn): - sql = ( - "SELECT stream_id, user_id, room_id" - " FROM room_tags_revisions as r" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - tag_ids = await self.db.runInteraction( - "get_all_updated_tags", get_all_updated_tags_txn - ) - - def get_tag_content(txn, tag_ids): - sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" - results = [] - for stream_id, user_id, room_id in tag_ids: - txn.execute(sql, (user_id, room_id)) - tags = [] - for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) - tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, (user_id, room_id, tag_json))) - - return results - - batch_size = 50 - results = [] - for i in range(0, len(tag_ids), batch_size): - tags = await self.db.runInteraction( - "get_all_updated_tag_content", - get_tag_content, - tag_ids[i : i + batch_size], - ) - results.extend(tags) - - limited = False - upto_token = current_id - if len(results) >= limit: - upto_token = results[-1][0] - limited = True - - return results, upto_token, limited - - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): - """Get all the tags for the rooms where the tags have changed since the - given version - - Args: - user_id(str): The user to get the tags for. - stream_id(int): The earliest update to get for the user. - Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. - """ - - def get_updated_tags_txn(txn): - sql = ( - "SELECT room_id from room_tags_revisions" - " WHERE user_id = ? AND stream_id > ?" - ) - txn.execute(sql, (user_id, stream_id)) - room_ids = [row[0] for row in txn] - return room_ids - - changed = self._account_data_stream_cache.has_entity_changed( - user_id, int(stream_id) - ) - if not changed: - return {} - - room_ids = yield self.db.runInteraction( - "get_updated_tags", get_updated_tags_txn - ) - - results = {} - if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) - for room_id in room_ids: - results[room_id] = tags_by_room.get(room_id, {}) - - return results - - def get_tags_for_room(self, user_id, room_id): - """Get all the tags for the given room - Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for - Returns: - A deferred list of string tags. - """ - return self.db.simple_select_list( - table="room_tags", - keyvalues={"user_id": user_id, "room_id": room_id}, - retcols=("tag", "content"), - desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} - ) - - -class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): - """Add a tag to a room for a user. - Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. - Returns: - A deferred that completes once the tag has been added. 
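The string splicing in get_all_updated_tags above can look odd at first glance: because room_tags.content already holds serialised JSON, the per-room tag map is assembled without re-parsing it. A minimal sketch (our own helper name, same splicing):

    import json

    def tags_to_json(rows):
        # rows: iterable of (tag, content_json) where content_json is already a
        # serialised JSON document, as stored in room_tags.content.
        parts = ["%s:%s" % (json.dumps(tag), content) for tag, content in rows]
        return "{" + ",".join(parts) + "}"

    assert tags_to_json([("m.favourite", '{"order":0.5}')]) == '{"m.favourite":{"order":0.5}}'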
- """ - content_json = json.dumps(content) - - def add_tag_txn(txn, next_id): - self.db.simple_upsert_txn( - txn, - table="room_tags", - keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, - values={"content": content_json}, - ) - self._update_revision_txn(txn, user_id, room_id, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.db.runInteraction("add_tag", add_tag_txn, next_id) - - self.get_tags_for_user.invalidate((user_id,)) - - result = self._account_data_id_gen.get_current_token() - return result - - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): - """Remove a tag from a room for a user. - Returns: - A deferred that completes once the tag has been removed - """ - - def remove_tag_txn(txn, next_id): - sql = ( - "DELETE FROM room_tags " - " WHERE user_id = ? AND room_id = ? AND tag = ?" - ) - txn.execute(sql, (user_id, room_id, tag)) - self._update_revision_txn(txn, user_id, room_id, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) - - self.get_tags_for_user.invalidate((user_id,)) - - result = self._account_data_id_gen.get_current_token() - return result - - def _update_revision_txn(self, txn, user_id, room_id, next_id): - """Update the latest revision of the tags for the given user and room. - - Args: - txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. - """ - - txn.call_after( - self._account_data_stream_cache.entity_has_changed, user_id, next_id - ) - - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - - update_sql = ( - "UPDATE room_tags_revisions" - " SET stream_id = ?" - " WHERE user_id = ?" - " AND room_id = ?" - ) - txn.execute(update_sql, (next_id, user_id, room_id)) - - if txn.rowcount == 0: - insert_sql = ( - "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)" - " VALUES (?, ?, ?)" - ) - try: - txn.execute(insert_sql, (user_id, room_id, next_id)) - except self.database_engine.module.IntegrityError: - # Ignore insertion errors. It doesn't matter if the row wasn't - # inserted because if two updates happend concurrently the one - # with the higher stream_id will not be reported to a client - # unless the previous update has completed. It doesn't matter - # which stream_id ends up in the table, as long as it is higher - # than the id that the client has. - pass diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py deleted file mode 100644 index a9bf457939..0000000000 --- a/synapse/storage/data_stores/main/transactions.py +++ /dev/null @@ -1,269 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from collections import namedtuple - -from canonicaljson import encode_canonical_json - -from twisted.internet import defer - -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import Database -from synapse.util.caches.expiringcache import ExpiringCache - -db_binary_type = memoryview - -logger = logging.getLogger(__name__) - - -_TransactionRow = namedtuple( - "_TransactionRow", - ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), -) - -_UpdateTransactionRow = namedtuple( - "_TransactionRow", ("response_code", "response_json") -) - -SENTINEL = object() - - -class TransactionStore(SQLBaseStore): - """A collection of queries for handling PDUs. - """ - - def __init__(self, database: Database, db_conn, hs): - super(TransactionStore, self).__init__(database, db_conn, hs) - - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) - - self._destination_retry_cache = ExpiringCache( - cache_name="get_destination_retry_timings", - clock=self._clock, - expiry_ms=5 * 60 * 1000, - ) - - def get_received_txn_response(self, transaction_id, origin): - """For an incoming transaction from a given origin, check if we have - already responded to it. If so, return the response code and response - body (as a dict). - - Args: - transaction_id (str) - origin(str) - - Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) - """ - - return self.db.runInteraction( - "get_received_txn_response", - self._get_received_txn_response, - transaction_id, - origin, - ) - - def _get_received_txn_response(self, txn, transaction_id, origin): - result = self.db.simple_select_one_txn( - txn, - table="received_transactions", - keyvalues={"transaction_id": transaction_id, "origin": origin}, - retcols=( - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ), - allow_none=True, - ) - - if result and result["response_code"]: - return result["response_code"], db_to_json(result["response_json"]) - - else: - return None - - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and - should return for subsequent transactions with the same transaction_id - and origin. - - Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) - """ - - return self.db.simple_insert( - table="received_transactions", - values={ - "transaction_id": transaction_id, - "origin": origin, - "response_code": code, - "response_json": db_binary_type(encode_canonical_json(response_dict)), - "ts": self._clock.time_msec(), - }, - or_ignore=True, - desc="set_received_txn_response", - ) - - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): - """Gets the current retry timings (if any) for a given destination. 
- - Args: - destination (str) - - Returns: - None if not retrying - Otherwise a dict for the retry scheme - """ - - result = self._destination_retry_cache.get(destination, SENTINEL) - if result is not SENTINEL: - return result - - result = yield self.db.runInteraction( - "get_destination_retry_timings", - self._get_destination_retry_timings, - destination, - ) - - # We don't hugely care about race conditions between getting and - # invalidating the cache, since we time out fairly quickly anyway. - self._destination_retry_cache[destination] = result - return result - - def _get_destination_retry_timings(self, txn, destination): - result = self.db.simple_select_one_txn( - txn, - table="destinations", - keyvalues={"destination": destination}, - retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), - allow_none=True, - ) - - if result and result["retry_last_ts"] > 0: - return result - else: - return None - - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): - """Sets the current retry timings for a given destination. - Both timings should be zero if retrying is no longer occuring. - - Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms - """ - - self._destination_retry_cache.pop(destination, None) - return self.db.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) - - def _set_destination_retry_timings( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): - - if self.database_engine.can_native_upsert: - # Upsert retry time interval if retry_interval is zero (i.e. we're - # resetting it) or greater than the existing retry interval. - - sql = """ - INSERT INTO destinations ( - destination, failure_ts, retry_last_ts, retry_interval - ) - VALUES (?, ?, ?, ?) - ON CONFLICT (destination) DO UPDATE SET - failure_ts = EXCLUDED.failure_ts, - retry_last_ts = EXCLUDED.retry_last_ts, - retry_interval = EXCLUDED.retry_interval - WHERE - EXCLUDED.retry_interval = 0 - OR destinations.retry_interval < EXCLUDED.retry_interval - """ - - txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) - - return - - self.database_engine.lock_table(txn, "destinations") - - # We need to be careful here as the data may have changed from under us - # due to a worker setting the timings. 
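Both the ON CONFLICT clause above and the locked fallback below apply the same overwrite rule: a zero interval resets the destination, and otherwise the stored backoff interval is only ever allowed to grow. Expressed as a plain predicate (sketch, helper name is ours):

    def should_overwrite_retry(stored_interval: int, new_interval: int) -> bool:
        # Mirrors "EXCLUDED.retry_interval = 0 OR
        #          destinations.retry_interval < EXCLUDED.retry_interval".
        return new_interval == 0 or stored_interval < new_interval

    assert should_overwrite_retry(stored_interval=5000, new_interval=0)         # reset wins
    assert should_overwrite_retry(stored_interval=5000, new_interval=10000)     # backoff grows
    assert not should_overwrite_retry(stored_interval=5000, new_interval=2000)  # never shrink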
- - prev_row = self.db.simple_select_one_txn( - txn, - table="destinations", - keyvalues={"destination": destination}, - retcols=("failure_ts", "retry_last_ts", "retry_interval"), - allow_none=True, - ) - - if not prev_row: - self.db.simple_insert_txn( - txn, - table="destinations", - values={ - "destination": destination, - "failure_ts": failure_ts, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - ) - elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self.db.simple_update_one_txn( - txn, - "destinations", - keyvalues={"destination": destination}, - updatevalues={ - "failure_ts": failure_ts, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - ) - - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - def _cleanup_transactions(self): - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - return self.db.runInteraction( - "_cleanup_transactions", _cleanup_transactions_txn - ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py deleted file mode 100644 index 5f1b919748..0000000000 --- a/synapse/storage/data_stores/main/ui_auth.py +++ /dev/null @@ -1,300 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Dict, Optional, Union - -import attr -from canonicaljson import json - -from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.types import JsonDict -from synapse.util import stringutils as stringutils - - -@attr.s -class UIAuthSessionData: - session_id = attr.ib(type=str) - # The dictionary from the client root level, not the 'auth' key. - clientdict = attr.ib(type=JsonDict) - # The URI and method the session was intiatied with. These are checked at - # each stage of the authentication to ensure that the asked for operation - # has not changed. - uri = attr.ib(type=str) - method = attr.ib(type=str) - # A string description of the operation that the current authentication is - # authorising. - description = attr.ib(type=str) - - -class UIAuthWorkerStore(SQLBaseStore): - """ - Manage user interactive authentication sessions. - """ - - async def create_ui_auth_session( - self, clientdict: JsonDict, uri: str, method: str, description: str, - ) -> UIAuthSessionData: - """ - Creates a new user interactive authentication session. - - The session can be used to track the stages necessary to authenticate a - user across multiple HTTP requests. - - Args: - clientdict: - The dictionary from the client root level, not the 'auth' key. - uri: - The URI this session was initiated with, this is checked at each - stage of the authentication to ensure that the asked for - operation has not changed. 
- method: - The method this session was initiated with, this is checked at each - stage of the authentication to ensure that the asked for - operation has not changed. - description: - A string description of the operation that the current - authentication is authorising. - Returns: - The newly created session. - Raises: - StoreError if a unique session ID cannot be generated. - """ - # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) - - # autogen a session ID and try to create it. We may clash, so just - # try a few times till one goes through, giving up eventually. - attempts = 0 - while attempts < 5: - session_id = stringutils.random_string(24) - - try: - await self.db.simple_insert( - table="ui_auth_sessions", - values={ - "session_id": session_id, - "clientdict": clientdict_json, - "uri": uri, - "method": method, - "description": description, - "serverdict": "{}", - "creation_time": self.hs.get_clock().time_msec(), - }, - desc="create_ui_auth_session", - ) - return UIAuthSessionData( - session_id, clientdict, uri, method, description - ) - except self.db.engine.module.IntegrityError: - attempts += 1 - raise StoreError(500, "Couldn't generate a session ID.") - - async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: - """Retrieve a UI auth session. - - Args: - session_id: The ID of the session. - Returns: - A dict containing the device information. - Raises: - StoreError if the session is not found. - """ - result = await self.db.simple_select_one( - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("clientdict", "uri", "method", "description"), - desc="get_ui_auth_session", - ) - - result["clientdict"] = db_to_json(result["clientdict"]) - - return UIAuthSessionData(session_id, **result) - - async def mark_ui_auth_stage_complete( - self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], - ): - """ - Mark a session stage as completed. - - Args: - session_id: The ID of the corresponding session. - stage_type: The completed stage type. - result: The result of the stage verification. - Raises: - StoreError if the session cannot be found. - """ - # Add (or update) the results of the current stage to the database. - # - # Note that we need to allow for the same stage to complete multiple - # times here so that registration is idempotent. - try: - await self.db.simple_upsert( - table="ui_auth_sessions_credentials", - keyvalues={"session_id": session_id, "stage_type": stage_type}, - values={"result": json.dumps(result)}, - desc="mark_ui_auth_stage_complete", - ) - except self.db.engine.module.IntegrityError: - raise StoreError(400, "Unknown session ID: %s" % (session_id,)) - - async def get_completed_ui_auth_stages( - self, session_id: str - ) -> Dict[str, Union[str, bool, JsonDict]]: - """ - Retrieve the completed stages of a UI authentication session. - - Args: - session_id: The ID of the session. - Returns: - The completed stages mapped to the result of the verification of - that auth-type. - """ - results = {} - for row in await self.db.simple_select_list( - table="ui_auth_sessions_credentials", - keyvalues={"session_id": session_id}, - retcols=("stage_type", "result"), - desc="get_completed_ui_auth_stages", - ): - results[row["stage_type"]] = db_to_json(row["result"]) - - return results - - async def set_ui_auth_clientdict( - self, session_id: str, clientdict: JsonDict - ) -> None: - """ - Store an updated clientdict for a given session ID. 
- - Args: - session_id: The ID of this session as returned from check_auth - clientdict: - The dictionary from the client root level, not the 'auth' key. - """ - # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) - - await self.db.simple_update_one( - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - updatevalues={"clientdict": clientdict_json}, - desc="set_ui_auth_client_dict", - ) - - async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): - """ - Store a key-value pair into the sessions data associated with this - request. This data is stored server-side and cannot be modified by - the client. - - Args: - session_id: The ID of this session as returned from check_auth - key: The key to store the data under - value: The data to store - Raises: - StoreError if the session cannot be found. - """ - await self.db.runInteraction( - "set_ui_auth_session_data", - self._set_ui_auth_session_data_txn, - session_id, - key, - value, - ) - - def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): - # Get the current value. - result = self.db.simple_select_one_txn( - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), - ) - - # Update it and add it back to the database. - serverdict = db_to_json(result["serverdict"]) - serverdict[key] = value - - self.db.simple_update_one_txn( - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - updatevalues={"serverdict": json.dumps(serverdict)}, - ) - - async def get_ui_auth_session_data( - self, session_id: str, key: str, default: Optional[Any] = None - ) -> Any: - """ - Retrieve data stored with set_session_data - - Args: - session_id: The ID of this session as returned from check_auth - key: The key to store the data under - default: Value to return if the key has not been set - Raises: - StoreError if the session cannot be found. - """ - result = await self.db.simple_select_one( - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), - desc="get_ui_auth_session_data", - ) - - serverdict = db_to_json(result["serverdict"]) - - return serverdict.get(key, default) - - -class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): - """ - Remove sessions which were last used earlier than the expiration time. - - Args: - expiration_time: The latest time that is still considered valid. - This is an epoch time in milliseconds. - - """ - return self.db.runInteraction( - "delete_old_ui_auth_sessions", - self._delete_old_ui_auth_sessions_txn, - expiration_time, - ) - - def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): - # Get the expired sessions. - sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" - txn.execute(sql, [expiration_time]) - session_ids = [r[0] for r in txn.fetchall()] - - # Delete the corresponding completed credentials. - self.db.simple_delete_many_txn( - txn, - table="ui_auth_sessions_credentials", - column="session_id", - iterable=session_ids, - keyvalues={}, - ) - - # Finally, delete the sessions. 
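Stepping back to create_ui_auth_session earlier in this file: the session ID is generated in application code and inserted optimistically, retrying a handful of times on a clash. A generic sketch of that pattern, with KeyError standing in for the database IntegrityError and the helper names being ours:

    import random
    import string

    def generate_unique_session_id(insert, attempts: int = 5) -> str:
        # `insert` is any callable that raises KeyError when the ID already exists.
        for _ in range(attempts):
            session_id = "".join(random.choices(string.ascii_letters + string.digits, k=24))
            try:
                insert(session_id)
                return session_id
            except KeyError:
                continue
        raise RuntimeError("Couldn't generate a session ID.")

    seen = set()

    def insert(session_id):
        if session_id in seen:
            raise KeyError(session_id)
        seen.add(session_id)

    print(generate_unique_session_id(insert))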
- self.db.simple_delete_many_txn( - txn, - table="ui_auth_sessions", - column="session_id", - iterable=session_ids, - keyvalues={}, - ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py deleted file mode 100644 index 942e51fd3a..0000000000 --- a/synapse/storage/data_stores/main/user_directory.py +++ /dev/null @@ -1,837 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.data_stores.main.state import StateFilter -from synapse.storage.data_stores.main.state_deltas import StateDeltasStore -from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import get_domain_from_id, get_localpart_from_id -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - - -TEMP_TABLE = "_temp_populate_user_directory" - - -class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - - # How many records do we calculate before sending it to - # add_users_who_share_private_rooms? - SHARE_PRIVATE_WORKING_SET = 500 - - def __init__(self, database: Database, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) - - self.server_name = hs.hostname - - self.db.updates.register_background_update_handler( - "populate_user_directory_createtables", - self._populate_user_directory_createtables, - ) - self.db.updates.register_background_update_handler( - "populate_user_directory_process_rooms", - self._populate_user_directory_process_rooms, - ) - self.db.updates.register_background_update_handler( - "populate_user_directory_process_users", - self._populate_user_directory_process_users, - ) - self.db.updates.register_background_update_handler( - "populate_user_directory_cleanup", self._populate_user_directory_cleanup - ) - - @defer.inlineCallbacks - def _populate_user_directory_createtables(self, progress, batch_size): - - # Get all the rooms that we want to process. - def _make_staging_area(txn): - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" - ) - txn.execute(sql) - - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_position(position TEXT NOT NULL)" - ) - txn.execute(sql) - - # Get rooms we want to process from the database - sql = """ - SELECT room_id, count(*) FROM current_state_events - GROUP BY room_id - """ - txn.execute(sql) - rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) - del rooms - - # If search all users is on, get all the users we want to add. 
- if self.hs.config.user_directory_search_all_users: - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_users(user_id TEXT NOT NULL)" - ) - txn.execute(sql) - - txn.execute("SELECT name FROM users") - users = [{"user_id": x[0]} for x in txn.fetchall()] - - self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) - - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.db.runInteraction( - "populate_user_directory_temp_build", _make_staging_area - ) - yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - - yield self.db.updates._end_background_update( - "populate_user_directory_createtables" - ) - return 1 - - @defer.inlineCallbacks - def _populate_user_directory_cleanup(self, progress, batch_size): - """ - Update the user directory stream position, then clean up the old tables. - """ - position = yield self.db.simple_select_one_onecol( - TEMP_TABLE + "_position", None, "position" - ) - yield self.update_user_directory_stream_pos(position) - - def _delete_staging_area(txn): - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - - yield self.db.runInteraction( - "populate_user_directory_cleanup", _delete_staging_area - ) - - yield self.db.updates._end_background_update("populate_user_directory_cleanup") - return 1 - - @defer.inlineCallbacks - def _populate_user_directory_process_rooms(self, progress, batch_size): - """ - Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. - """ - state = self.hs.get_state_handler() - - # If we don't have progress filed, delete everything. - if not progress: - yield self.delete_all_from_user_dir() - - def _get_next_batch(txn): - # Only fetch 250 rooms, so we don't fetch too many at once, even - # if those 250 rooms have less than batch_size state events. - sql = """ - SELECT room_id, events FROM %s - ORDER BY events DESC - LIMIT 250 - """ % ( - TEMP_TABLE + "_rooms", - ) - txn.execute(sql) - rooms_to_work_on = txn.fetchall() - - if not rooms_to_work_on: - return None - - # Get how many are left to process, so we can give status on how - # far we are in processing - txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") - progress["remaining"] = txn.fetchone()[0] - - return rooms_to_work_on - - rooms_to_work_on = yield self.db.runInteraction( - "populate_user_directory_temp_read", _get_next_batch - ) - - # No more rooms -- complete the transaction. - if not rooms_to_work_on: - yield self.db.updates._end_background_update( - "populate_user_directory_process_rooms" - ) - return 1 - - logger.debug( - "Processing the next %d rooms of %d remaining" - % (len(rooms_to_work_on), progress["remaining"]) - ) - - processed_event_count = 0 - - for room_id, event_count in rooms_to_work_on: - is_in_room = yield self.is_host_joined(room_id, self.server_name) - - if is_in_room: - is_public = yield self.is_room_world_readable_or_publicly_joinable( - room_id - ) - - users_with_profile = yield defer.ensureDeferred( - state.get_current_users_in_room(room_id) - ) - user_ids = set(users_with_profile) - - # Update each user in the user directory. 
- for user_id, profile in users_with_profile.items(): - yield self.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) - - to_insert = set() - - if is_public: - for user_id in user_ids: - if self.get_if_app_services_interested_in_user(user_id): - continue - - to_insert.add(user_id) - - if to_insert: - yield self.add_users_in_public_rooms(room_id, to_insert) - to_insert.clear() - else: - for user_id in user_ids: - if not self.hs.is_mine_id(user_id): - continue - - if self.get_if_app_services_interested_in_user(user_id): - continue - - for other_user_id in user_ids: - if user_id == other_user_id: - continue - - user_set = (user_id, other_user_id) - to_insert.add(user_set) - - # If it gets too big, stop and write to the database - # to prevent storing too much in RAM. - if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: - yield self.add_users_who_share_private_room( - room_id, to_insert - ) - to_insert.clear() - - if to_insert: - yield self.add_users_who_share_private_room(room_id, to_insert) - to_insert.clear() - - # We've finished a room. Delete it from the table. - yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) - # Update the remaining counter. - progress["remaining"] -= 1 - yield self.db.runInteraction( - "populate_user_directory", - self.db.updates._background_update_progress_txn, - "populate_user_directory_process_rooms", - progress, - ) - - processed_event_count += event_count - - if processed_event_count > batch_size: - # Don't process any more rooms, we've hit our batch size. - return processed_event_count - - return processed_event_count - - @defer.inlineCallbacks - def _populate_user_directory_process_users(self, progress, batch_size): - """ - If search_all_users is enabled, add all of the users to the user directory. - """ - if not self.hs.config.user_directory_search_all_users: - yield self.db.updates._end_background_update( - "populate_user_directory_process_users" - ) - return 1 - - def _get_next_batch(txn): - sql = "SELECT user_id FROM %s LIMIT %s" % ( - TEMP_TABLE + "_users", - str(batch_size), - ) - txn.execute(sql) - users_to_work_on = txn.fetchall() - - if not users_to_work_on: - return None - - users_to_work_on = [x[0] for x in users_to_work_on] - - # Get how many are left to process, so we can give status on how - # far we are in processing - sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" - txn.execute(sql) - progress["remaining"] = txn.fetchone()[0] - - return users_to_work_on - - users_to_work_on = yield self.db.runInteraction( - "populate_user_directory_temp_read", _get_next_batch - ) - - # No more users -- complete the transaction. - if not users_to_work_on: - yield self.db.updates._end_background_update( - "populate_user_directory_process_users" - ) - return 1 - - logger.debug( - "Processing the next %d users of %d remaining" - % (len(users_to_work_on), progress["remaining"]) - ) - - for user_id in users_to_work_on: - profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) - yield self.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) - - # We've finished processing a user. Delete it from the table. - yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) - # Update the remaining counter. 
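The private-room branch above pairs each eligible local user with every other member of the room and flushes the accumulator in bounded chunks. A minimal standalone sketch of that shape, in plain Python and ignoring the app-service and is_mine checks, might look like:

SHARE_PRIVATE_WORKING_SET = 500  # same flush threshold as the class constant above

def private_room_pairs(local_user_ids, all_user_ids):
    # One (local_user, other_user) tuple per ordered pair, as in the nested loop above.
    for user_id in local_user_ids:
        for other_user_id in all_user_ids:
            if user_id != other_user_id:
                yield (user_id, other_user_id)

def flush_in_batches(pairs, write, batch_size=SHARE_PRIVATE_WORKING_SET):
    # Accumulate pairs and hand them to `write` in bounded chunks to cap memory use.
    batch = set()
    for pair in pairs:
        batch.add(pair)
        if len(batch) >= batch_size:
            write(batch)
            batch = set()
    if batch:
        write(batch)

batches = []
flush_in_batches(
    private_room_pairs(["@a:hs"], ["@a:hs", "@b:hs", "@c:hs"]), batches.append, batch_size=1
)
assert batches == [{("@a:hs", "@b:hs")}, {("@a:hs", "@c:hs")}]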
- progress["remaining"] -= 1 - yield self.db.runInteraction( - "populate_user_directory", - self.db.updates._background_update_progress_txn, - "populate_user_directory_process_users", - progress, - ) - - return len(users_to_work_on) - - @defer.inlineCallbacks - def is_room_world_readable_or_publicly_joinable(self, room_id): - """Check if the room is either world_readable or publically joinable - """ - - # Create a state filter that only queries join and history state event - types_to_filter = ( - (EventTypes.JoinRules, ""), - (EventTypes.RoomHistoryVisibility, ""), - ) - - current_state_ids = yield self.get_filtered_current_state_ids( - room_id, StateFilter.from_types(types_to_filter) - ) - - join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) - if join_rules_id: - join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) - if join_rule_ev: - if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: - return True - - hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) - if hist_vis_id: - hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) - if hist_vis_ev: - if hist_vis_ev.content.get("history_visibility") == "world_readable": - return True - - return False - - def update_profile_in_user_dir(self, user_id, display_name, avatar_url): - """ - Update or add a user's profile in the user directory. - """ - - def _update_profile_in_user_dir_txn(txn): - new_entry = self.db.simple_upsert_txn( - txn, - table="user_directory", - keyvalues={"user_id": user_id}, - values={"display_name": display_name, "avatar_url": avatar_url}, - lock=False, # We're only inserter - ) - - if isinstance(self.database_engine, PostgresEngine): - # We weight the localpart most highly, then display name and finally - # server name - if self.database_engine.can_native_upsert: - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector - """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - else: - # TODO: Remove this code after we've bumped the minimum version - # of postgres to always support upserts, so we can get rid of - # `new_entry` usage - if new_entry is True: - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) - """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - elif new_entry is False: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? 
- """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - user_id, - ), - ) - else: - raise RuntimeError( - "upsert returned None when 'can_native_upsert' is False" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name) if display_name else user_id - self.db.simple_upsert_txn( - txn, - table="user_directory_search", - keyvalues={"user_id": user_id}, - values={"value": value}, - lock=False, # We're only inserter - ) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - - return self.db.runInteraction( - "update_profile_in_user_dir", _update_profile_in_user_dir_txn - ) - - def add_users_who_share_private_room(self, room_id, user_id_tuples): - """Insert entries into the users_who_share_private_rooms table. The first - user should be a local user. - - Args: - room_id (str) - user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. - """ - - def _add_users_who_share_room_txn(txn): - self.db.simple_upsert_many_txn( - txn, - table="users_who_share_private_rooms", - key_names=["user_id", "other_user_id", "room_id"], - key_values=[ - (user_id, other_user_id, room_id) - for user_id, other_user_id in user_id_tuples - ], - value_names=(), - value_values=None, - ) - - return self.db.runInteraction( - "add_users_who_share_room", _add_users_who_share_room_txn - ) - - def add_users_in_public_rooms(self, room_id, user_ids): - """Insert entries into the users_who_share_private_rooms table. The first - user should be a local user. - - Args: - room_id (str) - user_ids (list[str]) - """ - - def _add_users_in_public_rooms_txn(txn): - - self.db.simple_upsert_many_txn( - txn, - table="users_in_public_rooms", - key_names=["user_id", "room_id"], - key_values=[(user_id, room_id) for user_id in user_ids], - value_names=(), - value_values=None, - ) - - return self.db.runInteraction( - "add_users_in_public_rooms", _add_users_in_public_rooms_txn - ) - - def delete_all_from_user_dir(self): - """Delete the entire user directory - """ - - def _delete_all_from_user_dir_txn(txn): - txn.execute("DELETE FROM user_directory") - txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_public_rooms") - txn.execute("DELETE FROM users_who_share_private_rooms") - txn.call_after(self.get_user_in_directory.invalidate_all) - - return self.db.runInteraction( - "delete_all_from_user_dir", _delete_all_from_user_dir_txn - ) - - @cached() - def get_user_in_directory(self, user_id): - return self.db.simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", - ) - - def update_user_directory_stream_pos(self, stream_id): - return self.db.simple_update_one( - table="user_directory_stream_pos", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - desc="update_user_directory_stream_pos", - ) - - -class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): - - # How many records do we calculate before sending it to - # add_users_who_share_private_rooms? 
- SHARE_PRIVATE_WORKING_SET = 500 - - def __init__(self, database: Database, db_conn, hs): - super(UserDirectoryStore, self).__init__(database, db_conn, hs) - - def remove_from_user_dir(self, user_id): - def _remove_from_user_dir_txn(txn): - self.db.simple_delete_txn( - txn, table="user_directory", keyvalues={"user_id": user_id} - ) - self.db.simple_delete_txn( - txn, table="user_directory_search", keyvalues={"user_id": user_id} - ) - self.db.simple_delete_txn( - txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} - ) - self.db.simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id}, - ) - self.db.simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"other_user_id": user_id}, - ) - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - - return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) - - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): - """Get all user_ids that are in the room directory because they're - in the given room_id - """ - user_ids_share_pub = yield self.db.simple_select_onecol( - table="users_in_public_rooms", - keyvalues={"room_id": room_id}, - retcol="user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids_share_priv = yield self.db.simple_select_onecol( - table="users_who_share_private_rooms", - keyvalues={"room_id": room_id}, - retcol="other_user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids = set(user_ids_share_pub) - user_ids.update(user_ids_share_priv) - - return user_ids - - def remove_user_who_share_room(self, user_id, room_id): - """ - Deletes entries in the users_who_share_*_rooms table. The first - user should be a local user. - - Args: - user_id (str) - room_id (str) - """ - - def _remove_user_who_share_room_txn(txn): - self.db.simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id, "room_id": room_id}, - ) - self.db.simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"other_user_id": user_id, "room_id": room_id}, - ) - self.db.simple_delete_txn( - txn, - table="users_in_public_rooms", - keyvalues={"user_id": user_id, "room_id": room_id}, - ) - - return self.db.runInteraction( - "remove_user_who_share_room", _remove_user_who_share_room_txn - ) - - @defer.inlineCallbacks - def get_user_dir_rooms_user_is_in(self, user_id): - """ - Returns the rooms that a user is in. - - Args: - user_id(str): Must be a local user - - Returns: - list: user_id - """ - rows = yield self.db.simple_select_onecol( - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_is_in", - ) - - pub_rows = yield self.db.simple_select_onecol( - table="users_in_public_rooms", - keyvalues={"user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_is_in", - ) - - users = set(pub_rows) - users.update(rows) - return list(users) - - @defer.inlineCallbacks - def get_rooms_in_common_for_users(self, user_id, other_user_id): - """Given two user_ids find out the list of rooms they share. - """ - sql = """ - SELECT room_id FROM ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? 
- ) AS f1 INNER JOIN ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) f2 USING (room_id) - """ - - rows = yield self.db.execute( - "get_rooms_in_common_for_users", None, sql, user_id, other_user_id - ) - - return [room_id for room_id, in rows] - - def get_user_directory_stream_pos(self): - return self.db.simple_select_one_onecol( - table="user_directory_stream_pos", - keyvalues={}, - retcol="stream_id", - desc="get_user_directory_stream_pos", - ) - - @defer.inlineCallbacks - def search_user_dir(self, user_id, search_term, limit): - """Searches for users in directory - - Returns: - dict of the form:: - - { - "limited": , # whether there were more results or not - "results": [ # Ordered by best match first - { - "user_id": , - "display_name": , - "avatar_url": - } - ] - } - """ - - if self.hs.config.user_directory_search_all_users: - join_args = (user_id,) - where_clause = "user_id != ?" - else: - join_args = (user_id,) - where_clause = """ - ( - EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id) - OR EXISTS ( - SELECT 1 FROM users_who_share_private_rooms - WHERE user_id = ? AND other_user_id = t.user_id - ) - ) - """ - - if isinstance(self.database_engine, PostgresEngine): - full_query, exact_query, prefix_query = _parse_query_postgres(search_term) - - # We order by rank and then if they have profile info - # The ranking algorithm is hand tweaked for "best" results. Broadly - # the idea is we give a higher weight to exact matches. - # The array of numbers are the weights for the various part of the - # search: (domain, _, display name, localpart) - sql = """ - SELECT d.user_id AS user_id, display_name, avatar_url - FROM user_directory_search as t - INNER JOIN user_directory AS d USING (user_id) - WHERE - %s - AND vector @@ to_tsquery('english', ?) - ORDER BY - (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) - * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) - * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) - * ( - 3 * ts_rank_cd( - '{0.1, 0.1, 0.9, 1.0}', - vector, - to_tsquery('english', ?), - 8 - ) - + ts_rank_cd( - '{0.1, 0.1, 0.9, 1.0}', - vector, - to_tsquery('english', ?), - 8 - ) - ) - DESC, - display_name IS NULL, - avatar_url IS NULL - LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (full_query, exact_query, prefix_query, limit + 1) - elif isinstance(self.database_engine, Sqlite3Engine): - search_query = _parse_query_sqlite(search_term) - - sql = """ - SELECT d.user_id AS user_id, display_name, avatar_url - FROM user_directory_search as t - INNER JOIN user_directory AS d USING (user_id) - WHERE - %s - AND value MATCH ? - ORDER BY - rank(matchinfo(user_directory_search)) DESC, - display_name IS NULL, - avatar_url IS NULL - LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (search_query, limit + 1) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - results = yield self.db.execute( - "search_user_dir", self.db.cursor_to_dict, sql, *args - ) - - limited = len(results) > limit - - return {"limited": limited, "results": results} - - -def _parse_query_sqlite(search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. 
- - We specifically add both a prefix and non prefix matching term so that - exact matches get ranked higher. - """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* OR %s)" % (result, result) for result in results) - - -def _parse_query_postgres(search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. - """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - - both = " & ".join("(%s:* | %s)" % (result, result) for result in results) - exact = " & ".join("%s" % (result,) for result in results) - prefix = " & ".join("%s:*" % (result,) for result in results) - - return both, exact, prefix diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py deleted file mode 100644 index d3038ff06d..0000000000 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ /dev/null @@ -1,113 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import operator - -from synapse.storage._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList - - -class UserErasureWorkerStore(SQLBaseStore): - @cached() - def is_user_erased(self, user_id): - """ - Check if the given user id has requested erasure - - Args: - user_id (str): full user id to check - - Returns: - Deferred[bool]: True if the user has requested erasure - """ - return self.db.simple_select_onecol( - table="erased_users", - keyvalues={"user_id": user_id}, - retcol="1", - desc="is_user_erased", - ).addCallback(operator.truth) - - @cachedList( - cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True - ) - def are_users_erased(self, user_ids): - """ - Checks which users in a list have requested erasure - - Args: - user_ids (iterable[str]): full user id to check - - Returns: - Deferred[dict[str, bool]]: - for each user, whether the user has requested erasure. - """ - # this serves the dual purpose of (a) making sure we can do len and - # iterate it multiple times, and (b) avoiding duplicates. - user_ids = tuple(set(user_ids)) - - rows = yield self.db.simple_select_many_batch( - table="erased_users", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="are_users_erased", - ) - erased_users = {row["user_id"] for row in rows} - - res = {u: u in erased_users for u in user_ids} - return res - - -class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: - """Indicate that user_id wishes their message history to be erased. 
- - Args: - user_id: full user_id to be erased - """ - - def f(txn): - # first check if they are already in the list - txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) - if txn.fetchone(): - return - - # they are not already there: do the insert. - txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,)) - - self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - - return self.db.runInteraction("mark_user_erased", f) - - def mark_user_not_erased(self, user_id: str) -> None: - """Indicate that user_id is no longer erased. - - Args: - user_id: full user_id to be un-erased - """ - - def f(txn): - # first check if they are already in the list - txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) - if not txn.fetchone(): - return - - # They are there, delete them. - self.simple_delete_one_txn( - txn, "erased_users", keyvalues={"user_id": user_id} - ) - - self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - - return self.db.runInteraction("mark_user_not_erased", f) diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py deleted file mode 100644 index 86e09f6229..0000000000 --- a/synapse/storage/data_stores/state/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py deleted file mode 100644 index be1fe97d79..0000000000 --- a/synapse/storage/data_stores/state/bg_updates.py +++ /dev/null @@ -1,372 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine -from synapse.storage.state import StateFilter - -logger = logging.getLogger(__name__) - - -MAX_STATE_DELTA_HOPS = 100 - - -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. 
- - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT ON (type, state_key) - type, state_key, event_id - FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) %s - ORDER BY type, state_key, state_group DESC - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? 
" + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - -class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" - - def __init__(self, database: Database, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) - self.db.updates.register_background_index_update( - self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, - index_name="state_groups_room_id_idx", - table="state_groups", - columns=["room_id"], - ) - - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self.db.execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - (prev_group,) = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. 
- continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in curr_state.items() - if prev_state.get(key, None) != value - } - - self.db.simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self.db.simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in delta_state.items() - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self.db.updates._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.db.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self.db.updates._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 diff --git a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql deleted file mode 100644 index ae09fa0065..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql deleted file mode 100644 index e85699e82e..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/* We used to create a table called current_state_resets, but this is no - * longer used and is removed in delta 54. - */ - -/* The outlier events that have aquired a state group typically through - * backfill. This is tracked separately to the events table, as assigning a - * state group change the position of the existing event in the stream - * ordering. - * However since a stream_ordering is assigned in persist_event for the - * (event, state) pair, we can use that stream_ordering to identify when - * the new state was assigned for the event. - */ -CREATE TABLE IF NOT EXISTS ex_outlier_stream( - event_stream_ordering BIGINT PRIMARY KEY NOT NULL, - event_id TEXT NOT NULL, - state_group BIGINT NOT NULL -); diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql deleted file mode 100644 index 1450313bfa..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- The following indices are redundant, other indices are equivalent or --- supersets -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql deleted file mode 100644 index 33980d02f0..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/data_stores/state/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql deleted file mode 100644 index 0f1fa68a89..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/35/state.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE state_group_edges( - state_group BIGINT NOT NULL, - prev_state_group BIGINT NOT NULL -); - -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql deleted file mode 100644 index 97e5067ef4..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py deleted file mode 100644 index 9fd1ccf6f7..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.engines import PostgresEngine - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - # if we already have some state groups, we want to start making new - # ones with a higher id. - cur.execute("SELECT max(id) FROM state_groups") - row = cur.fetchone() - - if row[0] is None: - start_val = 1 - else: - start_val = row[0] + 1 - - cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql deleted file mode 100644 index 7916ef18b2..0000000000 --- a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql deleted file mode 100644 index 35f97d6b3d..0000000000 --- a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -CREATE TABLE state_groups ( - id BIGINT PRIMARY KEY, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE state_groups_state ( - state_group BIGINT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE state_group_edges ( - state_group BIGINT NOT NULL, - prev_state_group BIGINT NOT NULL -); - -CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); -CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres deleted file mode 100644 index fcd926c9fb..0000000000 --- a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py deleted file mode 100644 index 7dada7f75f..0000000000 --- a/synapse/storage/data_stores/state/store.py +++ /dev/null @@ -1,644 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from collections import namedtuple -from typing import Dict, Iterable, List, Set, Tuple - -from twisted.internet import defer - -from synapse.api.constants import EventTypes -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.database import Database -from synapse.storage.state import StateFilter -from synapse.storage.types import Cursor -from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import StateMap -from synapse.util.caches.descriptors import cached -from synapse.util.caches.dictionary_cache import DictionaryCache - -logger = logging.getLogger(__name__) - - -MAX_STATE_DELTA_HOPS = 100 - - -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - -class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): - """A data store for fetching/storing state groups. - """ - - def __init__(self, database: Database, db_conn, hs): - super(StateGroupDataStore, self).__init__(database, db_conn, hs) - - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. 
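A compact way to see the split described in the comment above, with plain dicts standing in for the two DictionaryCaches (illustrative only, not the store's actual API):

MEMBER_TYPE = "m.room.member"

def split_state(state_ids):
    # Partition a {(event_type, state_key): event_id} map into the member and
    # non-member halves that the two caches hold.
    members = {k: v for k, v in state_ids.items() if k[0] == MEMBER_TYPE}
    non_members = {k: v for k, v in state_ids.items() if k[0] != MEMBER_TYPE}
    return members, non_members

def merged_lookup(group, member_cache, non_member_cache):
    # Recombine the two per-group entries, which is what a lazy-loading query
    # does after hitting both caches separately.
    state = dict(non_member_cache.get(group, {}))
    state.update(member_cache.get(group, {}))
    return state

state = {(MEMBER_TYPE, "@alice:hs"): "$m1", ("m.room.name", ""): "$n1"}
members, non_members = split_state(state)
assert merged_lookup(42, {42: members}, {42: non_members}) == state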
- - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000, - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", 500000, - ) - - def get_max_state_group_txn(txn: Cursor): - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - return txn.fetchone()[0] - - self._state_group_seq_gen = build_sequence_generator( - self.database_engine, get_max_state_group_txn, "state_group_id_seq" - ) - - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self.db.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.db.runInteraction( - "get_state_group_delta", _get_state_group_delta_txn - ) - - async def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter - ) -> Dict[int, StateMap[str]]: - """Returns the state groups for a given set of groups from the - database, filtering on types of state events. - - Args: - groups: list of state group IDs to query - state_filter: The state filter used to fetch state - from the database. - Returns: - Dict of state group to state map. - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = await self.db.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. 
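The decision of whether a cache entry fully answers a request (completed by the concrete-types loop just below) can be summarised in standalone form, with explicit arguments standing in for the StateFilter object; the names here are illustrative:

def covers_request(is_all, known_absent, cached_keys, wanted_keys, has_wildcards):
    # wanted_keys=None means "everything was requested".
    if is_all or wanted_keys is None:
        return is_all
    if has_wildcards:
        # A wildcard type may have state keys we never fetched.
        return False
    return all(k in cached_keys or k in known_absent for k in wanted_keys)

# Requesting one concrete key that the cache has is enough, even if the cache
# entry is not the full group state.
assert covers_request(
    is_all=False,
    known_absent=set(),
    cached_keys={("m.room.name", "")},
    wanted_keys=[("m.room.name", "")],
    has_wildcards=False,
)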
- for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - async def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Dict[int, StateMap[str]]: - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups: list of state groups for which we want - to get the state. - state_filter: The state filter used to fetch state - from the database. - Returns: - Dict of state group to state map. - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = await self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in group_to_state_dict.items(): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache( - self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter - ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups: list of state groups for which we want to get the state. - cache: the cache of group ids to state dicts which - we will pass through - either the normal state cache or the - specific members state cache. - state_filter: The state filter used to fetch state from the - database. - - Returns: - Tuple of dict of state_group_id to state map of entries in the - cache, and the state group ids either missing from the cache or - incomplete. 
- """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in group_to_state_dict.items(): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in group_state_dict.items(): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self._state_group_seq_gen.get_next_id_txn(txn) - - self.db.simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. 
- if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in delta_ids.items() - ], - ) - else: - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in current_state_ids.items() - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in current_state_ids.items() - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in current_state_ids.items() - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.db.runInteraction("store_state_group", _store_state_group_txn) - - def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: - """Deletes no longer referenced state groups and de-deltas any state - groups that reference them. - - Args: - room_id: The room the state groups belong to (must all be in the - same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. - """ - - return self.db.runInteraction( - "purge_unreferenced_state_groups", - self._purge_unreferenced_state_groups, - room_id, - state_groups_to_delete, - ) - - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - rows = self.db.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), - ) - - remaining_state_groups = { - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete - } - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. 
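Before the unreferenced groups can actually be dropped, any surviving group whose delta points at one of them has to be rewritten as a standalone snapshot, otherwise its state would become unreadable once the rows it chains through disappear. The following is a self-contained sketch of that de-delta step (not Synapse code), with dicts standing in for state_groups_state and state_group_edges.

def purge_unreferenced_groups(full, deltas, to_delete):
    """De-delta survivors that point at doomed groups, then drop the doomed rows.

    full:      dict of group_id -> complete state dict
    deltas:    dict of group_id -> (prev_group_id, delta dict)
    to_delete: set of group_ids that nothing references any more
    """

    def resolve(group):
        # Materialise the complete state for `group` by walking its chain.
        pending = []
        while group in deltas:
            prev_group, delta = deltas[group]
            pending.append(delta)
            group = prev_group
        state = dict(full[group])
        for delta in reversed(pending):
            state.update(delta)
        return state

    # Survivors whose prev_group is about to vanish must become full snapshots.
    survivors = [
        group
        for group, (prev_group, _) in deltas.items()
        if prev_group in to_delete and group not in to_delete
    ]
    for group in survivors:
        full[group] = resolve(group)  # like re-inserting every row for the group
        del deltas[group]             # like deleting its state_group_edges row

    # Now the unreferenced groups can be removed outright.
    for group in to_delete:
        full.pop(group, None)
        deltas.pop(group, None)


full = {1: {("m.room.create", ""): "$create"}}
deltas = {2: (1, {("m.room.name", ""): "$name"})}
purge_unreferenced_groups(full, deltas, to_delete={1})
assert full[2] == {("m.room.create", ""): "$create", ("m.room.name", ""): "$name"}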
- for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self.db.simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self.db.simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in curr_state.items() - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - async def get_previous_state_groups( - self, state_groups: Iterable[int] - ) -> Dict[int, int]: - """Fetch the previous groups of the given state groups. - - Args: - state_groups - - Returns: - A mapping from state group to previous state group. - """ - - rows = await self.db.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", - ) - - return {row["state_group"]: row["prev_state_group"] for row in rows} - - def purge_room_state(self, room_id, state_groups_to_delete): - """Deletes all record of a room from state tables - - Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete - """ - - return self.db.runInteraction( - "purge_room_state", - self._purge_room_state_txn, - room_id, - state_groups_to_delete, - ) - - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups", - column="id", - iterable=state_groups_to_delete, - keyvalues={}, - ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ce8757a400..4ada6f5563 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -279,7 +279,7 @@ class PerformanceCounters(object): return top_n_counters -class Database(object): +class DatabasePool(object): """Wraps a single physical database and connection pool. A single database may be used by multiple data stores. diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py new file mode 100644 index 0000000000..b163eebf39 --- /dev/null +++ b/synapse/storage/databases/__init__.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database + +logger = logging.getLogger(__name__) + + +class Databases(object): + """The various databases. + + These are low level interfaces to physical databases. + + Attributes: + main (DataStore) + """ + + def __init__(self, main_store_class, hs): + # Note we pass in the main store class here as workers use a different main + # store. + + self.databases = [] + self.main = None + self.state = None + self.persist_events = None + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn) + prepare_database( + db_conn, engine, hs.config, databases=database_config.databases, + ) + + database = DatabasePool(hs, database_config, engine) + + if "main" in database_config.databases: + logger.info("Starting 'main' data store") + + # Sanity check we don't try and configure the main store on + # multiple databases. + if self.main: + raise Exception("'main' data store already configured") + + self.main = main_store_class(database, db_conn, hs) + + # If we're on a process that can persist events also + # instantiate a `PersistEventsStore` + if hs.config.worker.writers.events == hs.get_instance_name(): + self.persist_events = PersistEventsStore( + hs, database, self.main + ) + + if "state" in database_config.databases: + logger.info("Starting 'state' data store") + + # Sanity check we don't try and configure the state store on + # multiple databases. + if self.state: + raise Exception("'state' data store already configured") + + self.state = StateGroupDataStore(database, db_conn, hs) + + db_conn.commit() + + self.databases.append(database) + + logger.info("Database %r prepared", db_name) + + # Sanity check that we have actually configured all the required stores. + if not self.main: + raise Exception("No 'main' data store configured") + + if not self.state: + raise Exception("No 'main' data store configured") diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py new file mode 100644 index 0000000000..17fa470919 --- /dev/null +++ b/synapse/storage/databases/main/__init__.py @@ -0,0 +1,596 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import calendar +import logging +import time + +from synapse.api.constants import PresenceState +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import ( + IdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from .account_data import AccountDataStore +from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .cache import CacheInvalidationWorkerStore +from .censor_events import CensorEventsStore +from .client_ips import ClientIpStore +from .deviceinbox import DeviceInboxStore +from .devices import DeviceStore +from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore +from .end_to_end_keys import EndToEndKeyStore +from .event_federation import EventFederationStore +from .event_push_actions import EventPushActionsStore +from .events_bg_updates import EventsBackgroundUpdatesStore +from .filtering import FilteringStore +from .group_server import GroupServerStore +from .keys import KeyStore +from .media_repository import MediaRepositoryStore +from .metrics import ServerMetricsStore +from .monthly_active_users import MonthlyActiveUsersStore +from .openid import OpenIdStore +from .presence import PresenceStore, UserPresenceState +from .profile import ProfileStore +from .purge_events import PurgeEventsStore +from .push_rule import PushRuleStore +from .pusher import PusherStore +from .receipts import ReceiptsStore +from .registration import RegistrationStore +from .rejections import RejectionsStore +from .relations import RelationsStore +from .room import RoomStore +from .roommember import RoomMemberStore +from .search import SearchStore +from .signatures import SignatureStore +from .state import StateStore +from .stats import StatsStore +from .stream import StreamStore +from .tags import TagsStore +from .transactions import TransactionStore +from .ui_auth import UIAuthStore +from .user_directory import UserDirectoryStore +from .user_erasure_store import UserErasureStore + +logger = logging.getLogger(__name__) + + +class DataStore( + EventsBackgroundUpdatesStore, + RoomMemberStore, + RoomStore, + RegistrationStore, + StreamStore, + ProfileStore, + PresenceStore, + TransactionStore, + DirectoryStore, + KeyStore, + StateStore, + SignatureStore, + ApplicationServiceStore, + PurgeEventsStore, + EventFederationStore, + MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore, + ApplicationServiceTransactionStore, + ReceiptsStore, + EndToEndKeyStore, + EndToEndRoomKeyStore, + SearchStore, + TagsStore, + AccountDataStore, + EventPushActionsStore, + OpenIdStore, + ClientIpStore, + DeviceStore, + DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, + UserErasureStore, + MonthlyActiveUsersStore, + StatsStore, + RelationsStore, + CensorEventsStore, + UIAuthStore, + CacheInvalidationWorkerStore, + ServerMetricsStore, +): + def __init__(self, database: DatabasePool, 
db_conn, hs): + self.hs = hs + self._clock = hs.get_clock() + self.database_engine = database.engine + + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_inbox", "stream_id" + ) + self._public_room_id_gen = StreamIdGenerator( + db_conn, "public_room_list_stream", "stream_id" + ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], + ) + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id" + ) + + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name="master", + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) + else: + self._cache_id_gen = None + + super(DataStore, self).__init__(database, db_conn, hs) + + self._presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( + db_conn, + "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_current_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, + ) + + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( + db_conn, + "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + # The federation outbox and the local device inbox uses the same + # stream_id generator. 
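Most of this constructor is wiring up prefilled stream change caches: each one remembers the most recent stream position at which an entity changed, so "has anything changed for X since token T?" can usually be answered without touching the database, falling back to "assume yes" for tokens older than what was prefilled. The snippet below is a simplified model of that idea, not the real StreamChangeCache API.

class MiniStreamChangeCache:
    """Simplified model: remembers the latest stream position per entity."""

    def __init__(self, earliest_pos, prefilled=None):
        # Below `earliest_pos` the cache knows nothing and must say "maybe".
        self._earliest_known = earliest_pos
        self._entity_to_pos = dict(prefilled or {})

    def entity_has_changed(self, entity, pos):
        if pos > self._earliest_known:
            self._entity_to_pos[entity] = max(
                self._entity_to_pos.get(entity, 0), pos
            )

    def has_entity_changed(self, entity, since_pos):
        if since_pos < self._earliest_known:
            # Too far back for the cache to know; the caller hits the DB.
            return True
        return self._entity_to_pos.get(entity, 0) > since_pos


# Prefilling from recent rows (what get_cache_dict provides) means a freshly
# restarted process does not answer "maybe changed" for every single entity.
cache = MiniStreamChangeCache(earliest_pos=100, prefilled={"@alice:hs": 103})
assert cache.has_entity_changed("@alice:hs", since_pos=100)    # changed at 103
assert not cache.has_entity_changed("@bob:hs", since_pos=100)  # nothing newer
assert cache.has_entity_changed("@bob:hs", since_pos=50)       # before the horizon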
+ device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( + db_conn, + "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max + ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max + ) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + + _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( + db_conn, + "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", + min_group_updates_id, + prefilled_cache=_group_updates_prefill, + ) + + self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() + + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup + + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.db_pool.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + + def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + def count_monthly_users(self): + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. 
+ """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + def generate_user_daily_visits(self): + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self.clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? 
+ AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + return self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) + + def get_users(self): + """Function to retrieve a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.db_pool.simple_select_list( + table="users", + keyvalues={}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "user_type", + "deactivated", + ], + desc="get_users", + ) + + def get_users_paginate( + self, start, limit, name=None, guests=True, deactivated=False + ): + """Function to retrieve a paginated list of users from + users list. This will return a json list of users and the + total number of users matching the filter criteria. + + Args: + start (int): start number to begin the query from + limit (int): number of rows to retrieve + name (string): filter for user names + guests (bool): whether to in include guest users + deactivated (bool): whether to include deactivated users + Returns: + defer.Deferred: resolves to list[dict[str, Any]], int + """ + + def get_users_paginate_txn(txn): + filters = [] + args = [] + + if name: + filters.append("name LIKE ?") + args.append("%" + name + "%") + + if not guests: + filters.append("is_guest = 0") + + if not deactivated: + filters.append("deactivated = 0") + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) + txn.execute(sql, args) + count = txn.fetchone()[0] + + args = [self.hs.config.server_name] + args + [limit, start] + sql = """ + SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + FROM users as u + LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + {} + ORDER BY u.name LIMIT ? OFFSET ? + """.format( + where_clause + ) + txn.execute(sql, args) + users = self.db_pool.cursor_to_dict(txn) + return users, count + + return self.db_pool.runInteraction( + "get_users_paginate_txn", get_users_paginate_txn + ) + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. 
+ + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.db_pool.simple_search_list( + table="users", + term=term, + col="name", + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + desc="search_users", + ) + + +def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): + """Called before upgrading an existing database to check that it is broadly sane + compared with the configuration. + """ + domain = config.server_name + + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + cur.execute(sql, (pat,)) + num_not_matching = cur.fetchall()[0][0] + if num_not_matching == 0: + return + + raise Exception( + "Found users in database not native to %s!\n" + "You cannot changed a synapse server_name after it's been configured" + % (domain,) + ) + + +__all__ = ["DataStore", "check_database_before_upgrade"] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py new file mode 100644 index 0000000000..2193d8fdc5 --- /dev/null +++ b/synapse/storage/databases/main/account_data.py @@ -0,0 +1,430 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, database: DatabasePool, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max + ) + + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() + + @cached() + def get_account_data_for_user(self, user_id): + """Get all the client account_data for a user. + + Args: + user_id(str): The user to get the account_data for. 
+ Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_account_data_for_user_txn(txn): + rows = self.db_pool.simple_select_list_txn( + txn, + "account_data", + {"user_id": user_id}, + ["account_data_type", "content"], + ) + + global_account_data = { + row["account_data_type"]: db_to_json(row["content"]) for row in rows + } + + rows = self.db_pool.simple_select_list_txn( + txn, + "room_account_data", + {"user_id": user_id}, + ["room_id", "account_data_type", "content"], + ) + + by_room = {} + for row in rows: + room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = db_to_json(row["content"]) + + return global_account_data, by_room + + return self.db_pool.runInteraction( + "get_account_data_for_user", get_account_data_for_user_txn + ) + + @cachedInlineCallbacks(num_args=2, max_entries=5000) + def get_global_account_data_by_type_for_user(self, data_type, user_id): + """ + Returns: + Deferred: A dict + """ + result = yield self.db_pool.simple_select_one_onecol( + table="account_data", + keyvalues={"user_id": user_id, "account_data_type": data_type}, + retcol="content", + desc="get_global_account_data_by_type_for_user", + allow_none=True, + ) + + if result: + return db_to_json(result) + else: + return None + + @cached(num_args=2) + def get_account_data_for_room(self, user_id, room_id): + """Get all the client account_data for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + Returns: + A deferred dict of the room account_data + """ + + def get_account_data_for_room_txn(txn): + rows = self.db_pool.simple_select_list_txn( + txn, + "room_account_data", + {"user_id": user_id, "room_id": room_id}, + ["account_data_type", "content"], + ) + + return { + row["account_data_type"]: db_to_json(row["content"]) for row in rows + } + + return self.db_pool.runInteraction( + "get_account_data_for_room", get_account_data_for_room_txn + ) + + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. + """ + + def get_account_data_for_room_and_type_txn(txn): + content_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True, + ) + + return db_to_json(content_json) if content_json else None + + return self.db_pool.runInteraction( + "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn + ) + + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + and type string. 
+ """ + if last_id == current_id: + return [] + + def get_updated_global_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user + + Args: + user_id(str): The user to get the account_data for. + stream_id(int): The point in the stream since which to get updates + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_updated_account_data_for_user_txn(txn): + sql = ( + "SELECT account_data_type, content FROM account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + global_account_data = {row[0]: db_to_json(row[1]) for row in txn} + + sql = ( + "SELECT room_id, account_data_type, content FROM room_account_data" + " WHERE user_id = ? AND stream_id > ?" 
+ ) + + txn.execute(sql, (user_id, stream_id)) + + account_data_by_room = {} + for row in txn: + room_account_data = account_data_by_room.setdefault(row[0], {}) + room_account_data[row[1]] = db_to_json(row[2]) + + return global_account_data, account_data_by_room + + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) + if not changed: + return defer.succeed(({}, {})) + + return self.db_pool.runInteraction( + "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", + ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + return False + + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "account_data_max_stream_id", + "stream_id", + extra_tables=[ + ("room_account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ], + ) + + super(AccountDataStore, self).__init__(database, db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so simple_upsert will + # retry if there is a conflict. + yield self.db_pool.simple_upsert( + desc="add_room_account_data", + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + values={"stream_id": next_id, "content": content_json}, + lock=False, + ) + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), content + ) + + result = self._account_data_id_gen.get_current_token() + return result + + @defer.inlineCallbacks + def add_account_data_for_user(self, user_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. 
+ account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so simple_upsert will retry if + # there is a conflict. + yield self.db_pool.simple_upsert( + desc="add_user_account_data", + table="account_data", + keyvalues={"user_id": user_id, "account_data_type": account_data_type}, + values={"stream_id": next_id, "content": content_json}, + lock=False, + ) + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + # + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( + (account_data_type, user_id) + ) + + result = self._account_data_id_gen.get_current_token() + return result + + def _update_max_stream_id(self, next_id): + """Update the max stream_id + + Args: + next_id(int): The the revision to advance to. + """ + + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. + + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + + return self.db_pool.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py new file mode 100644 index 0000000000..055a3962dc --- /dev/null +++ b/synapse/storage/databases/main/appservice.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
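The add_account_data_* methods in the account-data store above share one shape: take the next stream id, upsert the row (the unique constraint makes the lockless upsert safe), and only then bump the legacy max-stream table and the change cache, accepting that a crash in between merely wastes a stream id. Here is a compact standalone sketch of that ordering (not Synapse code), with in-memory dicts and a counter standing in for the tables and the StreamIdGenerator.

import itertools
import json

_next_stream_id = itertools.count(1).__next__  # stand-in for StreamIdGenerator

account_data = {}  # (user_id, data_type) -> (stream_id, serialised content)
changed_at = {}    # user_id -> latest stream_id, i.e. the change cache


def add_account_data_for_user(user_id, data_type, content):
    stream_id = _next_stream_id()

    # Upsert first: the key (user_id, data_type) means a retry after a crash
    # overwrites the same row instead of duplicating it.
    account_data[(user_id, data_type)] = (stream_id, json.dumps(content))

    # Only once the write has happened do we tell readers something changed.
    changed_at[user_id] = max(changed_at.get(user_id, 0), stream_id)
    return stream_id


def get_updated_account_data_for_user(user_id, since_stream_id):
    # Cheap change-cache check before scanning any rows.
    if changed_at.get(user_id, 0) <= since_stream_id:
        return {}
    return {
        data_type: json.loads(content)
        for (uid, data_type), (stream_id, content) in account_data.items()
        if uid == user_id and stream_id > since_stream_id
    }


add_account_data_for_user("@alice:hs", "m.ignored_user_list", {"ignored_users": {}})
assert "m.ignored_user_list" in get_updated_account_data_for_user("@alice:hs", 0)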
+import logging +import re + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.appservice import AppServiceTransaction +from synapse.config.appservice import load_appservices +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore + +logger = logging.getLogger(__name__) + + +def _make_exclusive_regex(services_cache): + # We precompile a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in services_cache + for regex in service.get_exclusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + exclusive_user_regex = None + + return exclusive_user_regex + + +class ApplicationServiceWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + self.services_cache = load_appservices( + hs.hostname, hs.config.app_service_config_files + ) + self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) + + def get_app_services(self): + return self.services_cache + + def get_if_app_services_interested_in_user(self, user_id): + """Check if the user is one associated with an app service (exclusively) + """ + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False + + def get_app_service_by_user_id(self, user_id): + """Retrieve an application service from their user ID. + + All application services have associated with them a particular user ID. + There is no distinguishing feature on the user ID which indicates it + represents an application service. This function allows you to map from + a user ID to an application service. + + Args: + user_id(str): The user ID to see if it is an application service. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.sender == user_id: + return service + return None + + def get_app_service_by_token(self, token): + """Get the application service with the given appservice token. + + Args: + token (str): The application service token. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.token == token: + return service + return None + + def get_app_service_by_id(self, as_id): + """Get the application service with the given appservice ID. + + Args: + as_id (str): The application service ID. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.id == as_id: + return service + return None + + +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. 
+ pass + + +class ApplicationServiceTransactionWorkerStore( + ApplicationServiceWorkerStore, EventsWorkerStore +): + @defer.inlineCallbacks + def get_appservices_by_state(self, state): + """Get a list of application services based on their state. + + Args: + state(ApplicationServiceState): The state to filter on. + Returns: + A Deferred which resolves to a list of ApplicationServices, which + may be empty. + """ + results = yield self.db_pool.simple_select_list( + "application_services_state", {"state": state}, ["as_id"] + ) + # NB: This assumes this class is linked with ApplicationServiceStore + as_list = self.get_app_services() + services = [] + + for res in results: + for service in as_list: + if service.id == res["as_id"]: + services.append(service) + return services + + @defer.inlineCallbacks + def get_appservice_state(self, service): + """Get the application service state. + + Args: + service(ApplicationService): The service whose state to set. + Returns: + A Deferred which resolves to ApplicationServiceState. + """ + result = yield self.db_pool.simple_select_one( + "application_services_state", + {"as_id": service.id}, + ["state"], + allow_none=True, + desc="get_appservice_state", + ) + if result: + return result.get("state") + return None + + def set_appservice_state(self, service, state): + """Set the application service state. + + Args: + service(ApplicationService): The service whose state to set. + state(ApplicationServiceState): The connectivity state to apply. + Returns: + A Deferred which resolves when the state was set successfully. + """ + return self.db_pool.simple_upsert( + "application_services_state", {"as_id": service.id}, {"state": state} + ) + + def create_appservice_txn(self, service, events): + """Atomically creates a new transaction for this application service + with the given list of events. + + Args: + service(ApplicationService): The service who the transaction is for. + events(list): A list of events to put in the transaction. + Returns: + AppServiceTransaction: A new transaction. + """ + + def _create_appservice_txn(txn): + # work out new txn id (highest txn id for this service += 1) + # The highest id may be the last one sent (in which case it is last_txn) + # or it may be the highest in the txns list (which are waiting to be/are + # being sent) + last_txn_id = self._get_last_txn(txn, service.id) + + txn.execute( + "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", + (service.id,), + ) + highest_txn_id = txn.fetchone()[0] + if highest_txn_id is None: + highest_txn_id = 0 + + new_txn_id = max(highest_txn_id, last_txn_id) + 1 + + # Insert new txn into txn table + event_ids = json.dumps([e.event_id for e in events]) + txn.execute( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)", + (service.id, new_txn_id, event_ids), + ) + return AppServiceTransaction(service=service, id=new_txn_id, events=events) + + return self.db_pool.runInteraction( + "create_appservice_txn", _create_appservice_txn + ) + + def complete_appservice_txn(self, txn_id, service): + """Completes an application service transaction. + + Args: + txn_id(str): The transaction ID being completed. + service(ApplicationService): The application service which was sent + this transaction. + Returns: + A Deferred which resolves if this transaction was stored + successfully. + """ + txn_id = int(txn_id) + + def _complete_appservice_txn(txn): + # Debugging query: Make sure the txn being completed is EXACTLY +1 from + # what was there before. 
If it isn't, we've got problems (e.g. the AS + # has probably missed some events), so whine loudly but still continue, + # since it shouldn't fail completion of the transaction. + last_txn_id = self._get_last_txn(txn, service.id) + if (last_txn_id + 1) != txn_id: + logger.error( + "appservice: Completing a transaction which has an ID > 1 from " + "the last ID sent to this AS. We've either dropped events or " + "sent it to the AS out of order. FIX ME. last_txn=%s " + "completing_txn=%s service_id=%s", + last_txn_id, + txn_id, + service.id, + ) + + # Set current txn_id for AS to 'txn_id' + self.db_pool.simple_upsert_txn( + txn, + "application_services_state", + {"as_id": service.id}, + {"last_txn": txn_id}, + ) + + # Delete txn + self.db_pool.simple_delete_txn( + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, + ) + + return self.db_pool.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) + + @defer.inlineCallbacks + def get_oldest_unsent_txn(self, service): + """Get the oldest transaction which has not been sent for this + service. + + Args: + service(ApplicationService): The app service to get the oldest txn. + Returns: + A Deferred which resolves to an AppServiceTransaction or + None. + """ + + def _get_oldest_unsent_txn(txn): + # Monotonically increasing txn ids, so just select the smallest + # one in the txns table (we delete them when they are sent) + txn.execute( + "SELECT * FROM application_services_txns WHERE as_id=?" + " ORDER BY txn_id ASC LIMIT 1", + (service.id,), + ) + rows = self.db_pool.cursor_to_dict(txn) + if not rows: + return None + + entry = rows[0] + + return entry + + entry = yield self.db_pool.runInteraction( + "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn + ) + + if not entry: + return None + + event_ids = db_to_json(entry["event_ids"]) + + events = yield self.get_events_as_list(event_ids) + + return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + + def _get_last_txn(self, txn, service_id): + txn.execute( + "SELECT last_txn FROM application_services_state WHERE as_id=?", + (service_id,), + ) + last_txn_id = txn.fetchone() + if last_txn_id is None or last_txn_id[0] is None: # no row exists + return 0 + else: + return int(last_txn_id[0]) # select 'last_txn' col + + def set_appservice_last_pos(self, pos): + def set_appservice_last_pos_txn(txn): + txn.execute( + "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) + ) + + return self.db_pool.runInteraction( + "set_appservice_last_pos", set_appservice_last_pos_txn + ) + + @defer.inlineCallbacks + def get_new_events_for_appservice(self, current_id, limit): + """Get all new evnets""" + + def get_new_events_for_appservice_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " (SELECT stream_ordering FROM appservice_stream_position)" + " < e.stream_ordering" + " AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" 
+ ) + + txn.execute(sql, (current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.db_pool.runInteraction( + "get_new_events_for_appservice", get_new_events_for_appservice_txn + ) + + events = yield self.get_events_as_list(event_ids) + + return upper_bound, events + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py new file mode 100644 index 0000000000..683afde52b --- /dev/null +++ b/synapse/storage/databases/main/cache.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging +from typing import Any, Iterable, List, Optional, Tuple + +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams import BackfillStream, CachesStream +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. 
+ + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == EventsStream.NAME: + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == BackfillStream.NAME: + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) + + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_message_count_for_user.invalidate_many((room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: 
Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db_pool.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_all_cache_and_stream(self, txn, cache_func): + """Invalidates the entire cache and adds it to the cache stream so slaves + will know to invalidate their caches. + """ + + txn.call_after(cache_func.invalidate_all) + self._send_invalidation_to_replication(txn, cache_func.__name__, None) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication( + self, txn, cache_name: str, keys: Optional[Iterable[Any]] + ): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name + keys: Entry to invalidate. If None will invalidate all. + """ + + if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic current state cache" + ) + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. 
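For illustration, a minimal standalone sketch of the manual __enter__/__exit__ pattern the comment above describes, using a hypothetical ID generator and a toy transaction object (none of this is Synapse API):

    from contextlib import contextmanager

    @contextmanager
    def get_next_id():
        # Stand-in for an ID generator's get_next(): hand out the ID, then do
        # whatever bookkeeping belongs after the transaction has finished.
        yield 42
        print("generator cleanup ran after the transaction")

    class ToyTransaction:
        def __init__(self):
            self._after = []
        def call_after(self, fn, *args):
            self._after.append((fn, args))
        def finish(self):
            for fn, args in self._after:
                fn(*args)

    txn = ToyTransaction()
    ctx = get_next_id()
    stream_id = ctx.__enter__()                     # take the ID now, mid-transaction
    txn.call_after(ctx.__exit__, None, None, None)  # defer cleanup until the txn ends
    print("writing row with stream_id", stream_id)
    txn.finish()                                    # cleanup (i.e. __exit__) runs here
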
+ stream_id = self._cache_id_gen.get_next_txn(txn) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + if keys is not None: + keys = list(keys) + + self.db_pool.simple_insert_txn( + txn, + table="cache_invalidation_stream_by_instance", + values={ + "stream_id": stream_id, + "instance_name": self._instance_name, + "cache_func": cache_name, + "keys": keys, + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_cache_stream_token(self, instance_name): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token(instance_name) + else: + return 0 diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py new file mode 100644 index 0000000000..1de8249563 --- /dev/null +++ b/synapse/storage/databases/main/censor_events.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.internet import defer + +from synapse.events.utils import prune_event_dict +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.events import encode_json +from synapse.storage.databases.main.events_worker import EventsWorkerStore + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + def _censor_redactions(): + return run_as_background_process( + "_censor_redactions", self._censor_redactions + ) + + if self.hs.config.redaction_retention_period is not None: + hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + + async def _censor_redactions(self): + """Censors all redactions older than the configured period that haven't + been censored yet. + + By censor we mean update the event_json table with the redacted event. + """ + + if self.hs.config.redaction_retention_period is None: + return + + if not ( + await self.db_pool.updates.has_completed_background_update( + "redactions_have_censored_ts_idx" + ) + ): + # We don't want to run this until the appropriate index has been + # created. + return + + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + + # We fetch all redactions that: + # 1. point to an event we have, + # 2. has a received_ts from before the cut off, and + # 3. we haven't yet censored. + # + # This is limited to 100 events to ensure that we don't try and do too + # much at once. We'll get called again so this should eventually catch + # up. 
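As a rough standalone illustration of that catch-up behaviour (hypothetical helper names, not Synapse code): a worker that handles at most `batch_size` items per run still drains an arbitrarily large backlog across repeated invocations.

    pending = list(range(250))            # pretend backlog of uncensored redactions

    def run_once(batch_size=100):
        batch = pending[:batch_size]      # fetch at most batch_size items
        for item in batch:
            pending.remove(item)          # "censor" it
        return len(batch)

    runs = 0
    while run_once():
        runs += 1
    print("backlog drained after", runs, "productive runs")   # 3
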
+ sql = """ + SELECT redactions.event_id, redacts FROM redactions + LEFT JOIN events AS original_event ON ( + redacts = original_event.event_id + ) + WHERE NOT have_censored + AND redactions.received_ts <= ? + ORDER BY redactions.received_ts ASC + LIMIT ? + """ + + rows = await self.db_pool.execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) + + updates = [] + + for redaction_id, event_id in rows: + redaction_event = await self.get_event(redaction_id, allow_none=True) + original_event = await self.get_event( + event_id, allow_rejected=True, allow_none=True + ) + + # The SQL above ensures that we have both the redaction and + # original event, so if the `get_event` calls return None it + # means that the redaction wasn't allowed. Either way we know that + # the result won't change so we mark the fact that we've checked. + if ( + redaction_event + and original_event + and original_event.internal_metadata.is_redacted() + ): + # Redaction was allowed + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) + else: + # Redaction wasn't allowed + pruned_json = None + + updates.append((redaction_id, event_id, pruned_json)) + + def _update_censor_txn(txn): + for redaction_id, event_id, pruned_json in updates: + if pruned_json: + self._censor_event_txn(txn, event_id, pruned_json) + + self.db_pool.simple_update_one_txn( + txn, + table="redactions", + keyvalues={"event_id": redaction_id}, + updatevalues={"have_censored": True}, + ) + + await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) + + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.db_pool.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. 
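A much-simplified, hypothetical sketch of what the pruning step above amounts to: strip the event down to a whitelist of keys before re-serialising it (the real `prune_event_dict` also filters `content` per event type and keeps a few more fields).

    import json

    ALLOWED_KEYS = {"event_id", "type", "room_id", "sender", "origin_server_ts"}

    def prune(event_dict):
        # Keep only the whitelisted top-level keys; everything else is dropped.
        return {k: v for k, v in event_dict.items() if k in ALLOWED_KEYS}

    event = {
        "event_id": "$abc:example.com",
        "type": "m.room.message",
        "room_id": "!room:example.com",
        "sender": "@alice:example.com",
        "origin_server_ts": 0,
        "content": {"body": "redact me"},
    }
    pruned_json = json.dumps(prune(event), sort_keys=True)
    print(pruned_json)    # content is gone; only the skeleton of the event remains
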
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.db_pool.runInteraction( + "delete_expired_event", delete_expired_event_txn + ) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.db_pool.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py new file mode 100644 index 0000000000..712c8d0264 --- /dev/null +++ b/synapse/storage/databases/main/client_ips.py @@ -0,0 +1,580 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, make_tuple_comparison_clause +from synapse.util.caches.descriptors import Cache + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the user IP 'last seen' time. 
Smaller +# times give more inserts into the database even for readonly API hits +# 120 seconds == 2 minutes +LAST_SEEN_GRANULARITY = 120 * 1000 + + +class ClientIpBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "user_ips_device_index", + index_name="user_ips_device_id", + table="user_ips", + columns=["user_id", "device_id", "last_seen"], + ) + + self.db_pool.updates.register_background_index_update( + "user_ips_last_seen_index", + index_name="user_ips_last_seen", + table="user_ips", + columns=["user_id", "last_seen"], + ) + + self.db_pool.updates.register_background_index_update( + "user_ips_last_seen_only_index", + index_name="user_ips_last_seen_only", + table="user_ips", + columns=["last_seen"], + ) + + self.db_pool.updates.register_background_update_handler( + "user_ips_analyze", self._analyze_user_ip + ) + + self.db_pool.updates.register_background_update_handler( + "user_ips_remove_dupes", self._remove_user_ip_dupes + ) + + # Register a unique index + self.db_pool.updates.register_background_index_update( + "user_ips_device_unique_index", + index_name="user_ips_user_token_ip_unique_index", + table="user_ips", + columns=["user_id", "access_token", "ip"], + unique=True, + ) + + # Drop the old non-unique index + self.db_pool.updates.register_background_update_handler( + "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique + ) + + # Update the last seen info in devices. + self.db_pool.updates.register_background_update_handler( + "devices_last_seen", self._devices_last_seen_update + ) + + @defer.inlineCallbacks + def _remove_user_ip_nonunique(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") + txn.close() + + yield self.db_pool.runWithConnection(f) + yield self.db_pool.updates._end_background_update( + "user_ips_drop_nonunique_index" + ) + return 1 + + @defer.inlineCallbacks + def _analyze_user_ip(self, progress, batch_size): + # Background update to analyze user_ips table before we run the + # deduplication background update. The table may not have been analyzed + # for ages due to the table locks. + # + # This will lock out the naive upserts to user_ips while it happens, but + # the analyze should be quick (28GB table takes ~10s) + def user_ips_analyze(txn): + txn.execute("ANALYZE user_ips") + + yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) + + yield self.db_pool.updates._end_background_update("user_ips_analyze") + + return 1 + + @defer.inlineCallbacks + def _remove_user_ip_dupes(self, progress, batch_size): + # This works function works by scanning the user_ips table in batches + # based on `last_seen`. For each row in a batch it searches the rest of + # the table to see if there are any duplicates, if there are then they + # are removed and replaced with a suitable row. + + # Fetch the start of the batch + begin_last_seen = progress.get("last_seen", 0) + + def get_last_seen(txn): + txn.execute( + """ + SELECT last_seen FROM user_ips + WHERE last_seen > ? + ORDER BY last_seen + LIMIT 1 + OFFSET ? 
+ """, + (begin_last_seen, batch_size), + ) + row = txn.fetchone() + if row: + return row[0] + else: + return None + + # Get a last seen that has roughly `batch_size` since `begin_last_seen` + end_last_seen = yield self.db_pool.runInteraction( + "user_ips_dups_get_last_seen", get_last_seen + ) + + # If it returns None, then we're processing the last batch + last = end_last_seen is None + + logger.info( + "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", + begin_last_seen, + end_last_seen, + ) + + def remove(txn): + # This works by looking at all entries in the given time span, and + # then for each (user_id, access_token, ip) tuple in that range + # checking for any duplicates in the rest of the table (via a join). + # It then only returns entries which have duplicates, and the max + # last_seen across all duplicates, which can the be used to delete + # all other duplicates. + # It is efficient due to the existence of (user_id, access_token, + # ip) and (last_seen) indices. + + # Define the search space, which requires handling the last batch in + # a different way + if last: + clause = "? <= last_seen" + args = (begin_last_seen,) + else: + clause = "? <= last_seen AND last_seen < ?" + args = (begin_last_seen, end_last_seen) + + # (Note: The DISTINCT in the inner query is important to ensure that + # the COUNT(*) is accurate, otherwise double counting may happen due + # to the join effectively being a cross product) + txn.execute( + """ + SELECT user_id, access_token, ip, + MAX(device_id), MAX(user_agent), MAX(last_seen), + COUNT(*) + FROM ( + SELECT DISTINCT user_id, access_token, ip + FROM user_ips + WHERE {} + ) c + INNER JOIN user_ips USING (user_id, access_token, ip) + GROUP BY user_id, access_token, ip + HAVING count(*) > 1 + """.format( + clause + ), + args, + ) + res = txn.fetchall() + + # We've got some duplicates + for i in res: + user_id, access_token, ip, device_id, user_agent, last_seen, count = i + + # We want to delete the duplicates so we end up with only a + # single row. + # + # The naive way of doing this would be just to delete all rows + # and reinsert a constructed row. However, if there are a lot of + # duplicate rows this can cause the table to grow a lot, which + # can be problematic in two ways: + # 1. If user_ips is already large then this can cause the + # table to rapidly grow, potentially filling the disk. + # 2. Reinserting a lot of rows can confuse the table + # statistics for postgres, causing it to not use the + # correct indices for the query above, resulting in a full + # table scan. This is incredibly slow for large tables and + # can kill database performance. (This seems to mainly + # happen for the last query where the clause is simply `? < + # last_seen`) + # + # So instead we want to delete all but *one* of the duplicate + # rows. That is hard to do reliably, so we cheat and do a two + # step process: + # 1. Delete all rows with a last_seen strictly less than the + # max last_seen. This hopefully results in deleting all but + # one row the majority of the time, but there may be + # duplicate last_seen + # 2. If multiple rows remain, we fall back to the naive method + # and simply delete all rows and reinsert. + # + # Note that this relies on no new duplicate rows being inserted, + # but if that is happening then this entire process is futile + # anyway. + + # Do step 1: + + txn.execute( + """ + DELETE FROM user_ips + WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ? 
+ """, + (user_id, access_token, ip, last_seen), + ) + if txn.rowcount == count - 1: + # We deleted all but one of the duplicate rows, i.e. there + # is exactly one remaining and so there is nothing left to + # do. + continue + elif txn.rowcount >= count: + raise Exception( + "We deleted more duplicate rows from 'user_ips' than expected" + ) + + # The previous step didn't delete enough rows, so we fallback to + # step 2: + + # Drop all the duplicates + txn.execute( + """ + DELETE FROM user_ips + WHERE user_id = ? AND access_token = ? AND ip = ? + """, + (user_id, access_token, ip), + ) + + # Add in one to be the last_seen + txn.execute( + """ + INSERT INTO user_ips + (user_id, access_token, ip, device_id, user_agent, last_seen) + VALUES (?, ?, ?, ?, ?, ?) + """, + (user_id, access_token, ip, device_id, user_agent, last_seen), + ) + + self.db_pool.updates._background_update_progress_txn( + txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} + ) + + yield self.db_pool.runInteraction("user_ips_dups_remove", remove) + + if last: + yield self.db_pool.updates._end_background_update("user_ips_remove_dupes") + + return batch_size + + @defer.inlineCallbacks + def _devices_last_seen_update(self, progress, batch_size): + """Background update to insert last seen info into devices table + """ + + last_user_id = progress.get("last_user_id", "") + last_device_id = progress.get("last_device_id", "") + + def _devices_last_seen_update_txn(txn): + # This consists of two queries: + # + # 1. The sub-query searches for the next N devices and joins + # against user_ips to find the max last_seen associated with + # that device. + # 2. The outer query then joins again against user_ips on + # user/device/last_seen. This *should* hopefully only + # return one row, but if it does return more than one then + # we'll just end up updating the same device row multiple + # times, which is fine. + + where_clause, where_args = make_tuple_comparison_clause( + self.database_engine, + [("user_id", last_user_id), ("device_id", last_device_id)], + ) + + sql = """ + SELECT + last_seen, ip, user_agent, user_id, device_id + FROM ( + SELECT + user_id, device_id, MAX(u.last_seen) AS last_seen + FROM devices + INNER JOIN user_ips AS u USING (user_id, device_id) + WHERE %(where_clause)s + GROUP BY user_id, device_id + ORDER BY user_id ASC, device_id ASC + LIMIT ? + ) c + INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) + """ % { + "where_clause": where_clause + } + txn.execute(sql, where_args + [batch_size]) + + rows = txn.fetchall() + if not rows: + return 0 + + sql = """ + UPDATE devices + SET last_seen = ?, ip = ?, user_agent = ? + WHERE user_id = ? AND device_id = ? 
+ """ + txn.execute_batch(sql, rows) + + _, _, _, user_id, device_id = rows[-1] + self.db_pool.updates._background_update_progress_txn( + txn, + "devices_last_seen", + {"last_user_id": user_id, "last_device_id": device_id}, + ) + + return len(rows) + + updated = yield self.db_pool.runInteraction( + "_devices_last_seen_update", _devices_last_seen_update_txn + ) + + if not updated: + yield self.db_pool.updates._end_background_update("devices_last_seen") + + return updated + + +class ClientIpStore(ClientIpBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", keylen=4, max_entries=50000 + ) + + super(ClientIpStore, self).__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) + + if self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @defer.inlineCallbacks + def insert_client_ip( + self, user_id, access_token, ip, user_agent, device_id, now=None + ): + if not now: + now = int(self._clock.time_msec()) + key = (user_id, access_token, ip) + + try: + last_seen = self.client_ip_last_seen.get(key) + except KeyError: + last_seen = None + yield self.populate_monthly_active_users(user_id) + # Rate-limited inserts + if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: + return + + self.client_ip_last_seen.prefill(key, now) + + self._batch_row_update[key] = (user_agent, device_id, now) + + @wrap_as_background_process("update_client_ips") + def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.db_pool.is_running(): + return + + to_update = self._batch_row_update + self._batch_row_update = {} + + return self.db_pool.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update + ) + + def _update_client_ips_batch_txn(self, txn, to_update): + if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( + not self.database_engine.can_native_upsert + ): + self.database_engine.lock_table(txn, "user_ips") + + for entry in to_update.items(): + (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry + + try: + self.db_pool.simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + }, + values={ + "user_agent": user_agent, + "device_id": device_id, + "last_seen": last_seen, + }, + lock=False, + ) + + # Technically an access token might not be associated with + # a device so we need to check. + if device_id: + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. 
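A small sqlite3 sketch (hypothetical schema) of the distinction the comment above is drawing: a plain UPDATE of a row that has been deleted is a no-op, whereas a replace-style upsert would silently re-create it.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE devices (user_id TEXT, device_id TEXT, last_seen INTEGER, "
        "PRIMARY KEY (user_id, device_id))"
    )

    # Plain UPDATE: the (deleted / never-created) row simply isn't touched.
    cur = conn.execute(
        "UPDATE devices SET last_seen = ? WHERE user_id = ? AND device_id = ?",
        (1000, "@alice:example.com", "DEADBEEF"),
    )
    print(cur.rowcount)                                                # 0

    # Replace-style upsert: the row comes back from the dead.
    conn.execute(
        "INSERT OR REPLACE INTO devices (user_id, device_id, last_seen) VALUES (?, ?, ?)",
        ("@alice:example.com", "DEADBEEF", 1000),
    )
    print(conn.execute("SELECT count(*) FROM devices").fetchone()[0])  # 1
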
+ self.db_pool.simple_update_txn( + txn, + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues={ + "user_agent": user_agent, + "last_seen": last_seen, + "ip": ip, + }, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) + + @defer.inlineCallbacks + def get_last_client_ip_by_device(self, user_id, device_id): + """For each device_id listed, give the user_ip it was last seen on + + Args: + user_id (str) + device_id (str): If None fetches all devices for the user + + Returns: + defer.Deferred: resolves to a dict, where the keys + are (user_id, device_id) tuples. The values are also dicts, with + keys giving the column names + """ + + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + res = yield self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ) + + ret = {(d["user_id"], d["device_id"]): d for d in res} + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, did, last_seen = self._batch_row_update[key] + if not device_id or did == device_id: + ret[(user_id, device_id)] = { + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": did, + "last_seen": last_seen, + } + return ret + + @defer.inlineCallbacks + def get_user_ip_and_agents(self, user): + user_id = user.to_string() + results = {} + + for key in self._batch_row_update: + uid, access_token, ip, = key + if uid == user_id: + user_agent, _, last_seen = self._batch_row_update[key] + results[(access_token, ip)] = (user_agent, last_seen) + + rows = yield self.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "last_seen"], + desc="get_user_ip_and_agents", + ) + + results.update( + ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) + for row in rows + ) + return [ + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in results.items() + ] + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db_pool.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? 
+ ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ + + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn + ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py new file mode 100644 index 0000000000..874ecdf8d2 --- /dev/null +++ b/synapse/storage/databases/main/deviceinbox.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.util.caches.expiringcache import ExpiringCache + +logger = logging.getLogger(__name__) + + +class DeviceInboxWorkerStore(SQLBaseStore): + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute( + sql, (user_id, device_id, last_stream_id, current_stream_id, limit) + ) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(db_to_json(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return messages, stream_pos + + return self.db_pool.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn + ) + + @trace + @defer.inlineCallbacks + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves to the number of messages deleted. 
+ """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + + set_tag("last_deleted_stream_id", last_deleted_stream_id) + + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + log_kv({"message": "No changes in cache since last check"}) + return 0 + + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + return txn.rowcount + + count = yield self.db_pool.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + log_kv( + {"message": "deleted {} messages for device".format(count), "count": count} + ) + + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + return count + + @trace + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int|long): The last position of the device message stream + that the server sent up to. + current_stream_id(int|long): The current position of the device + message stream. + Returns: + Deferred ([dict], int|long): List of messages for the device and where + in the stream the messages got to. + """ + + set_tag("destination", destination) + set_tag("last_stream_id", last_stream_id) + set_tag("current_stream_id", current_stream_id) + set_tag("limit", limit) + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + log_kv({"message": "No new messages in stream"}) + return defer.succeed(([], current_stream_id)) + + if limit <= 0: + # This can happen if we run out of room for EDUs in the transaction. + return defer.succeed(([], last_stream_id)) + + @trace + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(db_to_json(row[1])) + if len(messages) < limit: + log_kv({"message": "Set stream position to current position"}) + stream_pos = current_stream_id + return messages, stream_pos + + return self.db_pool.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + @trace + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" 
+ " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.db_pool.runInteraction( + "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn + ) + + async def get_all_new_device_messages( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for to device replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_id, last_id + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_id, upper_pos)) + updates = [(row[0], row[1:]) for row in txn] + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" 
+ " GROUP BY destination" + ) + txn.execute(sql, (last_id, upper_pos)) + updates.extend((row[0], row[1:]) for row in txn) + + # Order by ascending stream ordering + updates.sort() + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + + +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.db_pool.updates.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox + ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + yield self.db_pool.runWithConnection(reindex_txn) + + yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) + + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + + @trace + @defer.inlineCallbacks + def add_messages_to_device_inbox( + self, local_messages_by_user_then_device, remote_messages_by_destination + ): + """Used to send messages from this server. + + Args: + sender_user_id(str): The ID of the user sending these messages. + local_messages_by_user_and_device(dict): + Dictionary of user_id to device_id to message. + remote_messages_by_destination(dict): + Dictionary of destination server_name to the EDU JSON to send. + Returns: + A deferred stream_id that resolves when the messages have been + inserted. + """ + + def add_messages_txn(txn, now_ms, stream_id): + # Add the local messages directly to the local inbox. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + # Add the remote messages to the federation outbox. + # We'll send them to a remote server when we next send a + # federation transaction to that destination. 
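For intuition, a toy in-memory model (hypothetical names) of the outbox flow described above: messages are queued per destination under a stream id, read back when a federation transaction is built, and deleted once the destination has acknowledged up to a given stream id.

    from collections import defaultdict

    outbox = defaultdict(list)              # destination -> [(stream_id, edu), ...]

    def queue(destination, stream_id, edu):
        outbox[destination].append((stream_id, edu))

    def take(destination, last_acked_id, limit=100):
        # What a get_new_device_msgs_for_remote-style read returns: everything
        # after the last acknowledged stream id, oldest first, up to a limit.
        pending = [m for m in outbox[destination] if m[0] > last_acked_id]
        return sorted(pending, key=lambda m: m[0])[:limit]

    def ack(destination, up_to_stream_id):
        # What a delete_device_msgs_for_remote-style delete does.
        outbox[destination] = [m for m in outbox[destination] if m[0] > up_to_stream_id]

    queue("remote.example.com", 1, {"type": "m.direct_to_device"})
    queue("remote.example.com", 2, {"type": "m.direct_to_device"})
    print(len(take("remote.example.com", 0)))   # 2 messages ready to send
    ack("remote.example.com", 2)                # destination acknowledged up to 2
    print(len(take("remote.example.com", 0)))   # 0 left
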
+ sql = ( + "INSERT INTO device_federation_outbox" + " (destination, stream_id, queued_ts, messages_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for destination, edu in remote_messages_by_destination.items(): + edu_json = json.dumps(edu) + rows.append((destination, stream_id, now_ms, edu_json)) + txn.executemany(sql, rows) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.db_pool.runInteraction( + "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) + for destination in remote_messages_by_destination.keys(): + self._device_federation_outbox_stream_cache.entity_has_changed( + destination, stream_id + ) + + return self._device_inbox_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_messages_from_remote_to_device_inbox( + self, origin, message_id, local_messages_by_user_then_device + ): + def add_messages_txn(txn, now_ms, stream_id): + # Check if we've already inserted a matching message_id for that + # origin. This can happen if the origin doesn't receive our + # acknowledgement from the first time we received the message. + already_inserted = self.db_pool.simple_select_one_txn( + txn, + table="device_federation_inbox", + keyvalues={"origin": origin, "message_id": message_id}, + retcols=("message_id",), + allow_none=True, + ) + if already_inserted is not None: + return + + # Add an entry for this message_id so that we know we've processed + # it. + self.db_pool.simple_insert_txn( + txn, + table="device_federation_inbox", + values={ + "origin": origin, + "message_id": message_id, + "received_ts": now_ms, + }, + ) + + # Add the messages to the approriate local device inboxes so that + # they'll be sent to the devices when they next sync. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.db_pool.runInteraction( + "add_messages_from_remote_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) + + return stream_id + + def _add_messages_to_local_device_inbox_txn( + self, txn, stream_id, messages_by_user_then_device + ): + local_by_user_then_device = {} + for user_id, messages_by_device in messages_by_user_then_device.items(): + messages_json_for_user = {} + devices = list(messages_by_device.keys()) + if len(devices) == 1 and devices[0] == "*": + # Handle wildcard device_ids. + sql = "SELECT device_id FROM devices WHERE user_id = ?" + txn.execute(sql, (user_id,)) + message_json = json.dumps(messages_by_device["*"]) + for row in txn: + # Add the message for all devices for this user on this + # server. + device = row[0] + messages_json_for_user[device] = message_json + else: + if not devices: + continue + + clause, args = make_in_list_sql_clause( + txn.database_engine, "device_id", devices + ) + sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause + + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. 
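A hypothetical sketch of the two ideas above: building a parameterised `column IN (...)` clause (roughly what `make_in_list_sql_clause` produces on SQLite), and chunking a long device list so each query stays a manageable size, which is what the TODO is gesturing at.

    def in_list_clause(column, values):
        placeholders = ", ".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), list(values)

    def batches(values, size):
        values = list(values)
        for i in range(0, len(values), size):
            yield values[i : i + size]

    devices = ["DEV%d" % i for i in range(250)]
    for chunk in batches(devices, 100):
        clause, args = in_list_clause("device_id", chunk)
        sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
        # txn.execute(sql, ["@alice:example.com"] + args)   # would run once per batch
        print(sql[:60], "...", len(args), "bound device ids")
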
+ txn.execute(sql, [user_id] + list(args)) + for row in txn: + # Only insert into the local inbox if the device exists on + # this server + device = row[0] + message_json = json.dumps(messages_by_device[device]) + messages_json_for_user[device] = message_json + + if messages_json_for_user: + local_by_user_then_device[user_id] = messages_json_for_user + + if not local_by_user_then_device: + return + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in local_by_user_then_device.items(): + for device_id, message_json in messages_by_device.items(): + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py new file mode 100644 index 0000000000..88a7aadfc6 --- /dev/null +++ b/synapse/storage/databases/main/devices.py @@ -0,0 +1,1311 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import List, Optional, Set, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import Codes, StoreError +from synapse.logging.opentracing import ( + get_active_span_text_map, + set_tag, + trace, + whitelisted_homeserver, +) +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) +from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr + +logger = logging.getLogger(__name__) + +DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( + "drop_device_list_streams_non_unique_indexes" +) + +BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" + + +class DeviceWorkerStore(SQLBaseStore): + def get_device(self, user_id, device_id): + """Retrieve a device. Only returns devices that are not marked as + hidden. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a dict containing the device information + Raises: + StoreError: if the device is not found + """ + return self.db_pool.simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. 
Only returns devices + that are not marked as hidden. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self.db_pool.simple_select_list( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user", + ) + + return {d["device_id"]: d for d in devices} + + @trace + @defer.inlineCallbacks + def get_device_updates_by_remote(self, destination, from_stream_id, limit): + """Get a stream of device updates to send to the given remote server. + + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + limit (int): Maximum number of device updates to return + Returns: + Deferred[tuple[int, list[tuple[string,dict]]]]: + current stream id (ie, the stream id of the last update included in the + response), and the list of updates, where each update is a pair of EDU + type and EDU contents + """ + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + return now_stream_id, [] + + updates = yield self.db_pool.runInteraction( + "get_device_updates_by_remote", + self._get_device_updates_by_remote_txn, + destination, + from_stream_id, + now_stream_id, + limit, + ) + + # Return an empty list if there are no updates + if not updates: + return now_stream_id, [] + + # get the cross-signing keys of the users in the list, so that we can + # determine which of the device changes were cross-signing keys + users = {r[0] for r in updates} + master_key_by_user = {} + self_signing_key_by_user = {} + for user in users: + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + master_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + + cross_signing_key = yield self.get_e2e_cross_signing_key( + user, "self_signing" + ) + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + + # Perform the equivalent of a GROUP BY + # + # Iterate through the updates list and copy non-duplicate + # (user_id, device_id) entries into a map, with the value being + # the max stream_id across each set of duplicate entries + # + # maps (user_id, device_id) -> (stream_id, opentracing_context) + # + # opentracing_context contains the opentracing metadata for the request + # that created the poke + # + # The most recent request's opentracing_context is used as the + # context which created the Edu. 
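The loop that follows implements that reduction; as a self-contained worked example with toy data (hypothetical IDs and contexts):

    updates = [
        ("@alice:example.com", "DEV1", 3, "ctx-a"),
        ("@alice:example.com", "DEV1", 7, "ctx-b"),
        ("@bob:example.com", "DEV2", 5, "ctx-c"),
    ]

    query_map = {}
    for user_id, device_id, stream_id, context in updates:
        key = (user_id, device_id)
        prev_stream_id, _ = query_map.get(key, (0, None))
        if stream_id > prev_stream_id:
            # Keep the highest stream_id seen for this device, along with the
            # context attached to that most recent update.
            query_map[key] = (stream_id, context)

    print(query_map)
    # {('@alice:example.com', 'DEV1'): (7, 'ctx-b'),
    #  ('@bob:example.com', 'DEV2'): (5, 'ctx-c')}
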
+ + query_map = {} + cross_signing_keys_by_user = {} + for user_id, device_id, update_stream_id, update_context in updates: + if ( + user_id in master_key_by_user + and device_id == master_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = master_key_by_user[user_id]["key_info"] + elif ( + user_id in self_signing_key_by_user + and device_id == self_signing_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["self_signing_key"] = self_signing_key_by_user[user_id][ + "key_info" + ] + else: + key = (user_id, device_id) + + previous_update_stream_id, _ = query_map.get(key, (0, None)) + + if update_stream_id > previous_update_stream_id: + query_map[key] = (update_stream_id, update_context) + + results = yield self._get_device_update_edus_by_remote( + destination, from_stream_id, query_map + ) + + # add the updated cross-signing keys to the results list + for user_id, result in cross_signing_keys_by_user.items(): + result["user_id"] = user_id + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("org.matrix.signing_key_update", result)) + + return now_stream_id, results + + def _get_device_updates_by_remote_txn( + self, txn, destination, from_stream_id, now_stream_id, limit + ): + """Return device update information for a given remote destination + + Args: + txn (LoggingTransaction): The transaction to execute + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + now_stream_id (int): The maximum stream_id to filter updates by, inclusive + limit (int): Maximum number of device updates to return + + Returns: + List: List of device updates + """ + # get the list of device updates that need to be sent + sql = """ + SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id + LIMIT ? 
+ """ + txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) + + return list(txn) + + @defer.inlineCallbacks + def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + """Returns a list of device update EDUs as well as E2EE keys + + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping + user_id/device_id to update stream_id and the relevent json-encoded + opentracing context + + Returns: + List[Dict]: List of objects representing an device update EDU + + """ + devices = ( + yield self.db_pool.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, + ) + if query_map + else {} + ) + + results = [] + for user_id, user_devices in devices.items(): + # The prev_id for the first row is always the last row before + # `from_stream_id` + prev_id = yield self._get_last_device_update_for_remote_user( + destination, user_id, from_stream_id + ) + + # make sure we go through the devices in stream order + device_ids = sorted( + user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + ) + + for device_id in device_ids: + device = user_devices[device_id] + stream_id, opentracing_context = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + "org.matrix.opentracing_context": opentracing_context, + } + + prev_id = stream_id + + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True + + results.append(("m.device_list_update", result)) + + return results + + def _get_last_device_update_for_remote_user( + self, destination, user_id, from_stream_id + ): + def f(txn): + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + return rows[0][0] + + return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.db_pool.runInteraction( + "mark_as_sent_devices_by_remote", + self._mark_as_sent_devices_by_remote_txn, + destination, + stream_id, + ) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # We update the device_lists_outbound_last_success with the successfully + # poked users. + sql = """ + SELECT user_id, coalesce(max(o.stream_id), 0) + FROM device_lists_outbound_pokes as o + WHERE destination = ? AND o.stream_id <= ? 
+ GROUP BY user_id + """ + txn.execute(sql, (destination, stream_id)) + rows = txn.fetchall() + + self.db_pool.simple_upsert_many_txn( + txn=txn, + table="device_lists_outbound_last_success", + key_names=("destination", "user_id"), + key_values=((destination, user_id) for user_id, _ in rows), + value_names=("stream_id",), + value_values=((stream_id,) for _, stream_id in rows), + ) + + # Delete all sent outbound pokes + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (destination, stream_id)) + + @defer.inlineCallbacks + def add_user_signature_change_to_streams(self, from_user_id, user_ids): + """Persist that a user has made new signatures + + Args: + from_user_id (str): the user who made the signatures + user_ids (list[str]): the users who were signed + """ + + with self._device_list_id_gen.get_next() as stream_id: + yield self.db_pool.runInteraction( + "add_user_sig_change_to_streams", + self._add_user_signature_change_txn, + from_user_id, + user_ids, + stream_id, + ) + return stream_id + + def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + txn.call_after( + self._user_signature_stream_cache.entity_has_changed, + from_user_id, + stream_id, + ) + self.db_pool.simple_insert_txn( + txn, + "user_signature_stream", + values={ + "stream_id": stream_id, + "from_user_id": from_user_id, + "user_ids": json.dumps(user_ids), + }, + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + @trace + @defer.inlineCallbacks + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. + + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ + user_ids = {user_id for user_id, _ in query_list} + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + + # We go and check if any of the users need to have their device lists + # resynced. If they do then we remove them from the cached list. 
+ users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + user_ids + ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device + else: + results[user_id] = yield self.get_cached_devices_for_user(user_id) + + set_tag("in_cache", results) + set_tag("not_in_cache", user_ids_not_in_cache) + + return user_ids_not_in_cache, results + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self.db_pool.simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="content", + desc="_get_cached_user_device", + ) + return db_to_json(content) + + @cachedInlineCallbacks() + def get_cached_devices_for_user(self, user_id): + devices = yield self.db_pool.simple_select_list( + table="device_lists_remote_cache", + keyvalues={"user_id": user_id}, + retcols=("device_id", "content"), + desc="get_cached_devices_for_user", + ) + return { + device["device_id"]: db_to_json(device["content"]) for device in devices + } + + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, + user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + def get_users_whose_devices_changed(self, from_key, user_ids): + """Get set of users whose devices have changed since `from_key` that + are in the given list of user_ids. + + Args: + from_key (str): The device lists stream token + user_ids (Iterable[str]) + + Returns: + Deferred[set[str]]: The set of user_ids whose devices have changed + since `from_key` + """ + from_key = int(from_key) + + # Get set of users who *may* have changed. Users not in the returned + # list have definitely not changed. + to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) + + if not to_check: + return defer.succeed(set()) + + def _get_users_whose_devices_changed_txn(txn): + changes = set() + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE stream_id > ? 
+ AND
+ """
+
+ for chunk in batch_iter(to_check, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql + clause, (from_key,) + tuple(args))
+ changes.update(user_id for user_id, in txn)
+
+ return changes
+
+ return self.db_pool.runInteraction(
+ "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_users_whose_signatures_changed(self, user_id, from_key):
+ """Get the users who have new cross-signing signatures made by `user_id` since
+ `from_key`.
+
+ Args:
+ user_id (str): the user who made the signatures
+ from_key (str): The device lists stream token
+ """
+ from_key = int(from_key)
+ if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+ sql = """
+ SELECT DISTINCT user_ids FROM user_signature_stream
+ WHERE from_user_id = ? AND stream_id > ?
+ """
+ rows = yield self.db_pool.execute(
+ "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ )
+ return {user for row in rows for user in db_to_json(row[0])}
+ else:
+ return set()
+
+ async def get_all_device_list_changes_for_remotes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for device lists replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_device_list_changes_for_remotes(txn):
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
+ sql = """
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
+ "get_all_device_list_changes_for_remotes",
+ _get_all_device_list_changes_for_remotes,
+ )
+
+ @cached(max_entries=10000)
+ def get_device_list_last_stream_id_for_remote(self, user_id):
+ """Get the last stream_id we got for a user. May be None if we haven't
+ got any information for them.
+ """ + return self.db_pool.simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_last_stream_id_for_remote", + allow_none=True, + ) + + @cachedList( + cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", + inlineCallbacks=True, + ) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self.db_pool.simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id"), + desc="get_device_list_last_stream_id_for_remotes", + ) + + results = {user_id: None for user_id in user_ids} + results.update({row["user_id"]: row["stream_id"] for row in rows}) + + return results + + @defer.inlineCallbacks + def get_user_ids_requiring_device_list_resync( + self, user_ids: Optional[Collection[str]] = None, + ) -> Set[str]: + """Given a list of remote users return the list of users that we + should resync the device lists for. If None is given instead of a list, + return every user that we should resync the device lists for. + + Returns: + The IDs of users whose device lists need resync. + """ + if user_ids: + rows = yield self.db_pool.simple_select_many_batch( + table="device_lists_remote_resync", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync_with_iterable", + ) + else: + rows = yield self.db_pool.simple_select_list( + table="device_lists_remote_resync", + keyvalues=None, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync", + ) + + return {row["user_id"] for row in rows} + + def mark_remote_user_device_cache_as_stale(self, user_id: str): + """Records that the server has reason to believe the cache of the devices + for the remote users is out of date. + """ + return self.db_pool.simple_upsert( + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, + values={}, + insertion_values={"added_ts": self._clock.time_msec()}, + desc="make_remote_user_device_cache_as_stale", + ) + + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. 
+ """ + + def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) + ) + + return self.db_pool.runInteraction( + "mark_remote_user_device_list_as_unsubscribed", + _mark_remote_user_device_list_as_unsubscribed_txn, + ) + + +class DeviceBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + + # create a unique index on device_lists_remote_cache + self.db_pool.updates.register_background_index_update( + "device_lists_remote_cache_unique_idx", + index_name="device_lists_remote_cache_unique_id", + table="device_lists_remote_cache", + columns=["user_id", "device_id"], + unique=True, + ) + + # And one on device_lists_remote_extremeties + self.db_pool.updates.register_background_index_update( + "device_lists_remote_extremeties_unique_idx", + index_name="device_lists_remote_extremeties_unique_idx", + table="device_lists_remote_extremeties", + columns=["user_id"], + unique=True, + ) + + # once they complete, we can remove the old non-unique indexes. + self.db_pool.updates.register_background_update_handler( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, + self._drop_device_list_streams_non_unique_indexes, + ) + + # clear out duplicate device list outbound pokes + self.db_pool.updates.register_background_update_handler( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + ) + + # a pair of background updates that were added during the 1.14 release cycle, + # but replaced with 58/06dlols_unique_idx.py + self.db_pool.updates.register_noop_background_update( + "device_lists_outbound_last_success_unique_idx", + ) + self.db_pool.updates.register_noop_background_update( + "drop_device_lists_outbound_last_success_non_unique_idx", + ) + + @defer.inlineCallbacks + def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") + txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") + txn.close() + + yield self.db_pool.runWithConnection(f) + yield self.db_pool.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) + return 1 + + async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + # for some reason, we have accumulated duplicate entries in + # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + # efficient. + # + # For each duplicate, we delete all the existing rows and put one back. + + KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] + last_row = progress.get( + "last_row", + {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, + ) + + def _txn(txn): + clause, args = make_tuple_comparison_clause( + self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS] + ) + sql = """ + SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts + FROM device_lists_outbound_pokes + WHERE %s + GROUP BY %s + HAVING count(*) > 1 + ORDER BY %s + LIMIT ? 
+ """ % ( + clause, # WHERE + ",".join(KEY_COLS), # GROUP BY + ",".join(KEY_COLS), # ORDER BY + ) + txn.execute(sql, args + [batch_size]) + rows = self.db_pool.cursor_to_dict(txn) + + row = None + for row in rows: + self.db_pool.simple_delete_txn( + txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + ) + + row["sent"] = False + self.db_pool.simple_insert_txn( + txn, "device_lists_outbound_pokes", row, + ) + + if row: + self.db_pool.updates._background_update_progress_txn( + txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + ) + + return len(rows) + + rows = await self.db_pool.runInteraction( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn + ) + + if not rows: + await self.db_pool.updates._end_background_update( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES + ) + + return rows + + +class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. + self.device_id_exists_cache = Cache( + name="device_id_exists", keylen=2, max_entries=10000 + ) + + self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) + + @defer.inlineCallbacks + def store_device(self, user_id, device_id, initial_device_display_name): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device. Ignored if device exists. + Returns: + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. + Raises: + StoreError: if the device is already in use + """ + key = (user_id, device_id) + if self.device_id_exists_cache.get(key, None): + return False + + try: + inserted = yield self.db_pool.simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name, + "hidden": False, + }, + desc="store_device", + or_ignore=True, + ) + if not inserted: + # if the device already exists, check if it's a real device, or + # if the device ID is reserved by something else + hidden = yield self.db_pool.simple_select_one_onecol( + "devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="hidden", + ) + if hidden: + raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + self.device_id_exists_cache.prefill(key, True) + return inserted + except StoreError: + raise + except Exception as e: + logger.error( + "store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, + device_id, + type(user_id).__name__, + user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, + e, + ) + raise StoreError(500, "Problem storing device.") + + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to delete + Returns: + defer.Deferred + """ + yield self.db_pool.simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + desc="delete_device", + ) + + self.device_id_exists_cache.invalidate((user_id, device_id)) + + @defer.inlineCallbacks + def delete_devices(self, user_id, device_ids): + """Deletes several devices. 
+
+ Args:
+ user_id (str): The ID of the user which owns the devices
+ device_ids (list): The IDs of the devices to delete
+ Returns:
+ defer.Deferred
+ """
+ yield self.db_pool.simple_delete_many(
+ table="devices",
+ column="device_id",
+ iterable=device_ids,
+ keyvalues={"user_id": user_id, "hidden": False},
+ desc="delete_devices",
+ )
+ for device_id in device_ids:
+ self.device_id_exists_cache.invalidate((user_id, device_id))
+
+ def update_device(self, user_id, device_id, new_display_name=None):
+ """Update a device. Only updates the device if it is not marked as
+ hidden.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to update
+ new_display_name (str|None): new displayname for device; None
+ to leave unchanged
+ Raises:
+ StoreError: if the device is not found
+ Returns:
+ defer.Deferred
+ """
+ updates = {}
+ if new_display_name is not None:
+ updates["display_name"] = new_display_name
+ if not updates:
+ return defer.succeed(None)
+ return self.db_pool.simple_update_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+ updatevalues=updates,
+ desc="update_device",
+ )
+
+ def update_remote_device_list_cache_entry(
+ self, user_id, device_id, content, stream_id
+ ):
+ """Updates a single device in the cache of a remote user's device list.
+
+ Note: assumes that we are the only thread that can be updating this user's
+ device list.
+
+ Args:
+ user_id (str): User to update device list for
+ device_id (str): ID of device being updated
+ content (dict): new data on this device
+ stream_id (int): the version of the device list
+
+ Returns:
+ Deferred[None]
+ """
+ return self.db_pool.runInteraction(
+ "update_remote_device_list_cache_entry",
+ self._update_remote_device_list_cache_entry_txn,
+ user_id,
+ device_id,
+ content,
+ stream_id,
+ )
+
+ def _update_remote_device_list_cache_entry_txn(
+ self, txn, user_id, device_id, content, stream_id
+ ):
+ if content.get("deleted"):
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+
+ txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
+ else:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"content": json.dumps(content)},
+ # we don't need to lock, because we assume we are the only thread
+ # updating this user's devices.
+ lock=False,
+ )
+
+ txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ values={"stream_id": stream_id},
+ # again, we can assume we are the only thread updating this user's
+ # extremity.
+ lock=False,
+ )
+
+ def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ """Replace the entire cache of the remote user's devices.
+
+ Note: assumes that we are the only thread that can be updating this user's
+ device list.
+ + Args: + user_id (str): User to update device list for + devices (list[dict]): list of device objects supplied over federation + stream_id (int): the version of the device list + + Returns: + Deferred[None] + """ + return self.db_pool.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, + devices, + stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): + self.db_pool.simple_delete_txn( + txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ], + ) + + txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + + self.db_pool.simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + values={"stream_id": stream_id}, + # we don't need to lock, because we can assume we are the only thread + # updating this user's extremity. + lock=False, + ) + + # If we're replacing the remote user's device list cache presumably + # we've done a full resync, so we remove the entry that says we need + # to resync + self.db_pool.simple_delete_txn( + txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, + ) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_ids, hosts): + """Persist that a user's devices have been updated, and which hosts + (if any) should be poked. + """ + if not device_ids: + return + + with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + yield self.db_pool.runInteraction( + "add_device_change_to_stream", + self._add_device_change_to_stream_txn, + user_id, + device_ids, + stream_ids, + ) + + if not hosts: + return stream_ids[-1] + + context = get_active_span_text_map() + with self._device_list_id_gen.get_next_mult( + len(hosts) * len(device_ids) + ) as stream_ids: + yield self.db_pool.runInteraction( + "add_device_outbound_poke_to_stream", + self._add_device_outbound_poke_to_stream_txn, + user_id, + device_ids, + hosts, + stream_ids, + context, + ) + + return stream_ids[-1] + + def _add_device_change_to_stream_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + stream_ids: List[str], + ): + txn.call_after( + self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], + ) + + min_stream_id = stream_ids[0] + + # Delete older entries in the table, as we really only care about + # when the latest change happened. + txn.executemany( + """ + DELETE FROM device_lists_stream + WHERE user_id = ? AND device_id = ? AND stream_id < ? 
+ """, + [(user_id, device_id, min_stream_id) for device_id in device_ids], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_stream", + values=[ + {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} + for stream_id, device_id in zip(stream_ids, device_ids) + ], + ) + + def _add_device_outbound_poke_to_stream_txn( + self, txn, user_id, device_ids, hosts, stream_ids, context, + ): + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) + + now = self._clock.time_msec() + next_stream_id = iter(stream_ids) + + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": next(next_stream_id), + "user_id": user_id, + "device_id": device_id, + "sent": False, + "ts": now, + "opentracing_context": json.dumps(context) + if whitelisted_homeserver(destination) + else "{}", + } + for destination in hosts + for device_id in device_ids + ], + ) + + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. + """ + yesterday = self._clock.time_msec() - prune_age + + def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. + select_sql = """ + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) 
+ ) + """ + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount + + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + + logger.info("Pruned %d device list outbound pokes", count) + + return run_as_background_process( + "prune_old_outbound_device_pokes", + self.db_pool.runInteraction, + "_prune_old_outbound_device_pokes", + _prune_txn, + ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py new file mode 100644 index 0000000000..7819bfcbb3 --- /dev/null +++ b/synapse/storage/databases/main/directory.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from typing import Optional + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached + +RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + + +class DirectoryWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_association_from_room_alias(self, room_alias): + """ Get's the room_id and server list for a given room_alias + + Args: + room_alias (RoomAlias) + + Returns: + Deferred: results in namedtuple with keys "room_id" and + "servers" or None if no association can be found + """ + room_id = yield self.db_pool.simple_select_one_onecol( + "room_aliases", + {"room_alias": room_alias.to_string()}, + "room_id", + allow_none=True, + desc="get_association_from_room_alias", + ) + + if not room_id: + return None + + servers = yield self.db_pool.simple_select_onecol( + "room_alias_servers", + {"room_alias": room_alias.to_string()}, + "server", + desc="get_association_from_room_alias", + ) + + if not servers: + return None + + return RoomAliasMapping(room_id, room_alias.to_string(), servers) + + def get_room_alias_creator(self, room_alias): + return self.db_pool.simple_select_one_onecol( + table="room_aliases", + keyvalues={"room_alias": room_alias}, + retcol="creator", + desc="get_room_alias_creator", + ) + + @cached(max_entries=5000) + def get_aliases_for_room(self, room_id): + return self.db_pool.simple_select_onecol( + "room_aliases", + {"room_id": room_id}, + "room_alias", + desc="get_aliases_for_room", + ) + + +class DirectoryStore(DirectoryWorkerStore): + @defer.inlineCallbacks + def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + """ Creates an association between a room alias and room_id/servers + + Args: + room_alias (RoomAlias) + room_id (str) + servers (list) + creator (str): Optional user_id of 
creator. + + Returns: + Deferred + """ + + def alias_txn(txn): + self.db_pool.simple_insert_txn( + txn, + "room_aliases", + { + "room_alias": room_alias.to_string(), + "room_id": room_id, + "creator": creator, + }, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="room_alias_servers", + values=[ + {"room_alias": room_alias.to_string(), "server": server} + for server in servers + ], + ) + + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + + try: + ret = yield self.db_pool.runInteraction( + "create_room_alias_association", alias_txn + ) + except self.database_engine.module.IntegrityError: + raise SynapseError( + 409, "Room alias %s already exists" % room_alias.to_string() + ) + return ret + + @defer.inlineCallbacks + def delete_room_alias(self, room_alias): + room_id = yield self.db_pool.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias + ) + + return room_id + + def _delete_room_alias_txn(self, txn, room_alias): + txn.execute( + "SELECT room_id FROM room_aliases WHERE room_alias = ?", + (room_alias.to_string(),), + ) + + res = txn.fetchone() + if res: + room_id = res[0] + else: + return None + + txn.execute( + "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),) + ) + + txn.execute( + "DELETE FROM room_alias_servers WHERE room_alias = ?", + (room_alias.to_string(),), + ) + + self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,)) + + return room_id + + def update_aliases_for_room( + self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + ): + """Repoint all of the aliases for a given room, to a different room. + + Args: + old_room_id: + new_room_id: + creator: The user to record as the creator of the new mapping. + If None, the creator will be left unchanged. + """ + + def _update_aliases_for_room_txn(txn): + update_creator_sql = "" + sql_params = (new_room_id, old_room_id) + if creator: + update_creator_sql = ", creator = ?" + sql_params = (new_room_id, creator, old_room_id) + + sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % ( + update_creator_sql, + ) + txn.execute(sql, sql_params) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (old_room_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (new_room_id,) + ) + + return self.db_pool.runInteraction( + "_update_aliases_for_room_txn", _update_aliases_for_room_txn + ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py new file mode 100644 index 0000000000..90152edc3c --- /dev/null +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -0,0 +1,439 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
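For illustration only (not part of the patch): a rough sketch of how the alias-repointing helper `update_aliases_for_room` above might be driven during a room upgrade. The handler function, its arguments, and the `store` object are hypothetical stand-ins rather than code from this changeset.

```python
from twisted.internet import defer


@defer.inlineCallbacks
def repoint_aliases_after_upgrade(store, old_room_id, new_room_id, upgrading_user_id):
    # Move every alias from the old room to the new one, recording the
    # upgrading user as the creator of the refreshed mappings.
    yield store.update_aliases_for_room(
        old_room_id, new_room_id, creator=upgrading_user_id
    )

    # Both rooms' alias caches are invalidated inside the transaction, so a
    # fresh lookup reflects the move immediately.
    aliases = yield store.get_aliases_for_room(new_room_id)
    return aliases
```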
+ +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.opentracing import log_kv, trace +from synapse.storage._base import SQLBaseStore, db_to_json + + +class EndToEndRoomKeyStore(SQLBaseStore): + @defer.inlineCallbacks + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + Raises: + StoreError + """ + + yield self.db_pool.simple_update_one( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + updatevalues={ + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + }, + desc="update_e2e_room_key", + ) + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self.db_pool.simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" + ) + + @trace + @defer.inlineCallbacks + def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. + If not specified, we return the keys for all the rooms in the backup. + session_id (str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we return all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. 
+ """ + + try: + version = int(version) + except ValueError: + return {"rooms": {}} + + keyvalues = {"user_id": user_id, "version": version} + if room_id: + keyvalues["room_id"] = room_id + if session_id: + keyvalues["session_id"] = session_id + + rows = yield self.db_pool.simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "user_id", + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", + ) + + sessions = {"rooms": {}} + for row in rows: + room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) + room_entry["sessions"][row["session_id"]] = { + "first_message_index": row["first_message_index"], + "forwarded_count": row["forwarded_count"], + # is_verified must be returned to the client as a boolean + "is_verified": bool(row["is_verified"]), + "session_data": db_to_json(row["session_data"]), + } + + return sessions + + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.db_pool.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": db_to_json(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. 
+ + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self.db_pool.simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + + @trace + @defer.inlineCallbacks + def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + + Args: + user_id(str): the user whose backup we're deleting from + version(str): the version ID of the backup for the set of keys we're deleting + room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + If not specified, we delete the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we delete all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred of the deletion transaction + """ + + keyvalues = {"user_id": user_id, "version": int(version)} + if room_id: + keyvalues["room_id"] = room_id + if session_id: + keyvalues["session_id"] = session_id + + yield self.db_pool.simple_delete( + table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" + ) + + @staticmethod + def _get_current_version(txn, user_id): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions " + "WHERE user_id=? AND deleted=0", + (user_id,), + ) + row = txn.fetchone() + if not row: + raise StoreError(404, "No current backup version") + return row[0] + + def get_e2e_room_keys_version_info(self, user_id, version=None): + """Get info metadata about a version of our room_keys backup. + + Args: + user_id(str): the user whose backup we're querying + version(str): Optional. the version ID of the backup we're querying about + If missing, we return the information about the current version. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present + Returns: + A deferred dict giving the info metadata for this backup version, with + fields including: + version(str) + algorithm(str) + auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup + """ + + def _get_e2e_room_keys_version_info_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + try: + this_version = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it isn't there. + raise StoreError(404, "No row found") + + result = self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, + retcols=("version", "algorithm", "auth_data", "etag"), + ) + result["auth_data"] = db_to_json(result["auth_data"]) + result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 + return result + + return self.db_pool.runInteraction( + "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn + ) + + @trace + def create_e2e_room_keys_version(self, user_id, info): + """Atomically creates a new version of this user's e2e_room_keys store + with the given version info. 
+ + Args: + user_id(str): the user whose backup we're creating a version + info(dict): the info about the backup version to be created + + Returns: + A deferred string for the newly created version ID + """ + + def _create_e2e_room_keys_version_txn(txn): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,), + ) + current_version = txn.fetchone()[0] + if current_version is None: + current_version = "0" + + new_version = str(int(current_version) + 1) + + self.db_pool.simple_insert_txn( + txn, + table="e2e_room_keys_versions", + values={ + "user_id": user_id, + "version": new_version, + "algorithm": info["algorithm"], + "auth_data": json.dumps(info["auth_data"]), + }, + ) + + return new_version + + return self.db_pool.runInteraction( + "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn + ) + + @trace + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): + """Update a given backup version + + Args: + user_id(str): the user whose backup version we're updating + version(str): the version ID of the backup version we're updating + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. If + None, then the etag is not updated + """ + updatevalues = {} + + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self.db_pool.simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) + + @trace + def delete_e2e_room_keys_version(self, user_id, version=None): + """Delete a given backup version of the user's room keys. + Doesn't delete their actual key data. + + Args: + user_id(str): the user whose backup version we're deleting + version(str): Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present, + or if the version requested doesn't exist. + """ + + def _delete_e2e_room_keys_version_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") + else: + this_version = version + + self.db_pool.simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + + return self.db_pool.simple_update_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": this_version}, + updatevalues={"deleted": 1}, + ) + + return self.db_pool.runInteraction( + "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn + ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py new file mode 100644 index 0000000000..40354b8304 --- /dev/null +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -0,0 +1,748 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Tuple + +from canonicaljson import encode_canonical_json, json + +from twisted.enterprise.adbapi import Connection +from twisted.internet import defer + +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import make_in_list_sql_clause +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + + +class EndToEndKeyWorkerStore(SQLBaseStore): + @trace + @defer.inlineCallbacks + def get_e2e_device_keys( + self, query_list, include_all_devices=False, include_deleted_devices=False + ): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys + include_deleted_devices (bool): whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. + Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. The key data will be a dict in the same format as the + DeviceKeys type returned by POST /_matrix/client/r0/keys/query. + """ + set_tag("query_list", query_list) + if not query_list: + return {} + + results = yield self.db_pool.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + rv = {} + for user_id, device_keys in results.items(): + rv[user_id] = {} + for device_id, device_info in device_keys.items(): + r = db_to_json(device_info.pop("key_json")) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + if "signatures" in device_info: + for sig_user_id, sigs in device_info["signatures"].items(): + r.setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + rv[user_id][device_id] = r + + return rv + + @trace + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ): + set_tag("include_all_devices", include_all_devices) + set_tag("include_deleted_devices", include_deleted_devices) + + query_clauses = [] + query_params = [] + signature_query_clauses = [] + signature_query_params = [] + + if include_all_devices is False: + include_deleted_devices = False + + if include_deleted_devices: + deleted_devices = set(query_list) + + for (user_id, device_id) in query_list: + query_clause = "user_id = ?" + query_params.append(user_id) + signature_query_clause = "target_user_id = ?" + signature_query_params.append(user_id) + + if device_id is not None: + query_clause += " AND device_id = ?" + query_params.append(device_id) + signature_query_clause += " AND target_device_id = ?" + signature_query_params.append(device_id) + + signature_query_clause += " AND user_id = ?" 
+ signature_query_params.append(user_id) + + query_clauses.append(query_clause) + signature_query_clauses.append(signature_query_clause) + + sql = ( + "SELECT user_id, device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" + " WHERE %s AND NOT d.hidden" + ) % ( + "LEFT" if include_all_devices else "INNER", + " OR ".join("(" + q + ")" for q in query_clauses), + ) + + txn.execute(sql, query_params) + rows = self.db_pool.cursor_to_dict(txn) + + result = {} + for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) + result.setdefault(row["user_id"], {})[row["device_id"]] = row + + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + + # get signatures on the device + signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + rows = self.db_pool.cursor_to_dict(txn) + + # add each cross-signing signature to the correct device in the result dict. + for row in rows: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + signature = row["signature"] + + target_user_result = result.get(target_user_id) + if not target_user_result: + continue + + target_device_result = target_user_result.get(target_device_id) + if not target_device_result: + # note that target_device_result will be None for deleted devices. + continue + + target_device_signatures = target_device_result.setdefault("signatures", {}) + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature + + log_kv(result) + return result + + @defer.inlineCallbacks + def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + """Retrieve a number of one-time keys for a user + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + key_ids(list[str]): list of key ids (excluding algorithm) to + retrieve + + Returns: + deferred resolving to Dict[(str, str), str]: map from (algorithm, + key_id) to json string for key + """ + + rows = yield self.db_pool.simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=key_ids, + retcols=("algorithm", "key_id", "key_json"), + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="add_e2e_one_time_keys_check", + ) + result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} + log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) + return result + + @defer.inlineCallbacks + def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + """Insert some new one time keys for a device. Errors if any of the + keys already exist. + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + time_now(long): insertion time to record (ms since epoch) + new_keys(iterable[(str, str, str)]: keys to add - each a tuple of + (algorithm, key_id, key json) + """ + + def _add_e2e_one_time_keys(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", new_keys) + # We are protected from race between lookup and insertion due to + # a unique constraint. 
If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self.db_pool.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + values=[ + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + "ts_added_ms": time_now, + "key_json": json_bytes, + } + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + yield self.db_pool.runInteraction( + "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + ) + + @cached(max_entries=10000) + def count_e2e_one_time_keys(self, user_id, device_id): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + + def _count_e2e_one_time_keys(txn): + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" + " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn: + result[algorithm] = key_count + return result + + return self.db_pool.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) + + @defer.inlineCallbacks + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data or None if not found + """ + res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + user_keys = res.get(user_id) + if not user_keys: + return None + return user_keys.get(key_type) + + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db_pool.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. 
+ + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + for user_chunk in batch_iter(user_ids, 100): + clause, params = make_in_list_sql_clause( + txn.database_engine, "k.user_id", user_chunk + ) + sql = ( + """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE + """ + + clause + ) + + txn.execute(sql, params) + rows = self.db_pool.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = db_to_json(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + for batch in batch_iter(devices.keys(), size=100): + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for _ in batch + ) + ) + query_params = [from_user_id] + for item in batch: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db_pool.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. 
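A standalone illustration (not part of the patch) of the shallow-copy pitfall the comment above guards against; the key data shown is made up and the dict is simplified to a single nesting level.

```python
cached = {"master": {"keys": {"ed25519:abc": "base64+key"}, "signatures": {}}}

# dict() / dict.copy() are shallow: the nested "master" dict is still shared,
# so writing signatures into the "copy" also mutates the cached entry.
polluted = dict(cached)
polluted["master"]["signatures"]["@alice:example.com"] = {"ed25519:DEV1": "sig"}
assert "@alice:example.com" in cached["master"]["signatures"]

# To mutate safely, each level that will be written to has to be copied
# explicitly, which is what the copy-chain in the code below does.
safe = {user: key.copy() for user, key in cached.items()}
safe["master"]["signatures"] = dict(safe["master"]["signatures"])
safe["master"]["signatures"]["@bob:example.com"] = {"ed25519:DEV2": "sig"}
assert "@bob:example.com" not in cached["master"]["signatures"]
```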
+ user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + + async def get_all_user_signature_changes_for_remotes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + + Note that the user signature stream represents when a user signs their + device with their user-signing key, which is not published to other + users or servers, so no `destination` is needed in the returned + list. However, this is needed to poke workers. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def _get_all_user_signature_changes_for_remotes_txn(txn): + sql = """ + SELECT stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(row[0], (row[1:])) for row in txn] + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_user_signature_changes_for_remotes", + _get_all_user_signature_changes_for_remotes_txn, + ) + + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. 
Returns whether there was a change + or the keys were already in the database. + """ + + def _set_e2e_device_keys_txn(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", device_keys) + + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False + + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, + ) + log_kv({"message": "Device keys stored."}) + return True + + return self.db_pool.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) + + def claim_e2e_one_time_keys(self, query_list): + """Take a list of one time keys out of the database""" + + @trace + def _claim_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm)) + for key_id, key_json in txn: + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + log_kv( + { + "message": "Executing claim e2e_one_time_keys transaction on database." + } + ) + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + log_kv({"message": "finished executing and invalidating cache"}) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + return result + + return self.db_pool.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) + + def delete_e2e_keys_by_device(self, user_id, device_id): + def delete_e2e_keys_by_device_txn(txn): + log_kv( + { + "message": "Deleting keys for device", + "device_id": device_id, + "user_id": user_id, + } + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return self.db_pool.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn + ) + + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + """Set a user's cross-signing key. 
+ + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user to set the signing key for + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + key (dict): the key data + """ + # the 'key' dict will look something like: + # { + # "user_id": "@alice:example.com", + # "usage": ["self_signing"], + # "keys": { + # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", + # }, + # "signatures": { + # "@alice:example.com": { + # "ed25519:base64+master+public+key": "base64+signature" + # } + # } + # } + # The "keys" property must only have one entry, which will be the public + # key, so we just grab the first value in there + pubkey = next(iter(key["keys"].values())) + + # The cross-signing keys need to occupy the same namespace as devices, + # since signatures are identified by device ID. So add an entry to the + # device table to make sure that we don't have a collision with device + # IDs. + # We only need to do this for local users, since remote servers should be + # responsible for checking this for their own users. + if self.hs.is_mine_id(user_id): + self.db_pool.simple_insert_txn( + txn, + "devices", + values={ + "user_id": user_id, + "device_id": pubkey, + "display_name": key_type + " signing key", + "hidden": True, + }, + ) + + # and finally, store the key itself + with self._cross_signing_id_gen.get_next() as stream_id: + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json.dumps(key), + "stream_id": stream_id, + }, + ) + + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + + def set_e2e_cross_signing_key(self, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + user_id (str): the user to set the user-signing key for + key_type (str): the type of cross-signing key to set + key (dict): the key data + """ + return self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + ) + + def store_e2e_cross_signing_signatures(self, user_id, signatures): + """Stores cross-signing signatures. + + Args: + user_id (str): the user who made the signatures + signatures (iterable[SignatureListItem]): signatures to add + """ + return self.db_pool.simple_insert_many( + "e2e_cross_signing_signatures", + [ + { + "user_id": user_id, + "key_id": item.signing_key_id, + "target_user_id": item.target_user_id, + "target_device_id": item.target_device_id, + "signature": item.signature, + } + for item in signatures + ], + "add_e2e_signing_key", + ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py new file mode 100644 index 0000000000..eddb32b4d3 --- /dev/null +++ b/synapse/storage/databases/main/event_federation.py @@ -0,0 +1,726 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import logging +from queue import Empty, PriorityQueue +from typing import Dict, List, Optional, Set, Tuple + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.util.caches.descriptors import cached +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given + ).addCallback(self.get_events_as_list) + + def get_auth_chain_ids( + self, + event_ids: List[str], + include_given: bool = False, + ignore_events: Optional[Set[str]] = None, + ): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids: state events + include_given: include the given events in result + ignore_events: Set of events to exclude from the returned auth + chain. This is useful if the caller will just discard the + given events anyway, and saves us from figuring out their auth + chains if not required. + + Returns: + list of event_ids + """ + return self.db_pool.runInteraction( + "get_auth_chain_ids", + self._get_auth_chain_ids_txn, + event_ids, + include_given, + ignore_events, + ) + + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): + if ignore_events is None: + ignore_events = set() + + if include_given: + results = set(event_ids) + else: + results = set() + + base_sql = "SELECT auth_id FROM event_auth WHERE " + + front = set(event_ids) + while front: + new_front = set() + for chunk in batch_iter(front, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", chunk + ) + txn.execute(base_sql + clause, args) + new_front.update(r[0] for r in txn) + + new_front -= ignore_events + new_front -= results + + front = new_front + results.update(front) + + return list(results) + + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db_pool.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. 
If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediately show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. + # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # topological ordering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # The sorted list of events whose auth chains we should walk. + search = [] # type: List[Tuple[int, str]] + + # We need to get the depth of the initial events for sorting purposes. + sql = """ + SELECT depth, event_id FROM events + WHERE %s + """ + # the list can be huge, so let's avoid looking them all up in one massive + # query. + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + # I think building a temporary list with fetchall is more efficient than + # just `search.extend(txn)`, but this is unconfirmed + search.extend(txn.fetchall()) + + # sort by depth + search.sort() + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. 
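# A naive, self-contained restatement of the "auth chain difference" defined
# in the docstring above (hypothetical names: `event_auth` maps an event id
# to the ids of its auth events; assumes at least one state set). The walk
# below aims for the same result without materialising the full chains.
def _auth_chain_difference_sketch(state_sets, event_auth):
    def chain(event_ids):
        seen = set(event_ids)
        frontier = set(event_ids)
        while frontier:
            frontier = {
                auth_id
                for event_id in frontier
                for auth_id in event_auth.get(event_id, ())
            } - seen
            seen |= frontier
        return seen

    chains = [chain(state_set) for state_set in state_sets]
    # Events reachable from some state set but not from every state set.
    return set().union(*chains) - set.intersection(*chains)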
+ if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) + + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. + sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + + def get_oldest_events_in_room(self, room_id): + return self.db_pool.runInteraction( + "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id + ) + + def get_oldest_events_with_depth_in_room(self, room_id): + return self.db_pool.runInteraction( + "get_oldest_events_with_depth_in_room", + self.get_oldest_events_with_depth_in_room_txn, + room_id, + ) + + def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): + sql = ( + "SELECT b.event_id, MAX(e.depth) FROM events as e" + " INNER JOIN event_edges as g" + " ON g.event_id = e.event_id" + " INNER JOIN event_backward_extremities as b" + " ON g.prev_event_id = b.event_id" + " WHERE b.room_id = ? AND g.is_state is ?" + " GROUP BY b.event_id" + ) + + txn.execute(sql, (room_id, False)) + + return dict(txn) + + @defer.inlineCallbacks + def get_max_depth_of(self, event_ids): + """Returns the max depth of a set of event IDs + + Args: + event_ids (list[str]) + + Returns + Deferred[int] + """ + rows = yield self.db_pool.simple_select_many_batch( + table="events", + column="event_id", + iterable=event_ids, + retcols=("depth",), + desc="get_max_depth_of", + ) + + if not rows: + return 0 + else: + return max(row["depth"] for row in rows) + + def _get_oldest_events_in_room_txn(self, txn, room_id): + return self.db_pool.simple_select_onecol_txn( + txn, + table="event_backward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + ) + + def get_prev_events_for_room(self, room_id: str): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. 
+ + Args: + room_id (str): room_id + + Returns: + Deferred[List[str]]: the event ids of the forward extremites + + """ + + return self.db_pool.runInteraction( + "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id + ) + + def _get_prev_events_for_room_txn(self, txn, room_id: str): + # we just use the 10 newest events. Older events will become + # prev_events of future events. + + sql = """ + SELECT e.event_id FROM event_forward_extremities AS f + INNER JOIN events AS e USING (event_id) + WHERE f.room_id = ? + ORDER BY e.depth DESC + LIMIT 10 + """ + + txn.execute(sql, (room_id,)) + + return [row[0] for row in txn] + + def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + """Get the top rooms with at least N extremities. + + Args: + min_count (int): The minimum number of extremities + limit (int): The maximum number of rooms to return. + room_id_filter (iterable[str]): room_ids to exclude from the results + + Returns: + Deferred[list]: At most `limit` room IDs that have at least + `min_count` extremities, sorted by extremity count. + """ + + def _get_rooms_with_many_extremities_txn(txn): + where_clause = "1=1" + if room_id_filter: + where_clause = "room_id NOT IN (%s)" % ( + ",".join("?" for _ in room_id_filter), + ) + + sql = """ + SELECT room_id FROM event_forward_extremities + WHERE %s + GROUP BY room_id + HAVING count(*) > ? + ORDER BY count(*) DESC + LIMIT ? + """ % ( + where_clause, + ) + + query_args = list(itertools.chain(room_id_filter, [min_count, limit])) + txn.execute(sql, query_args) + return [room_id for room_id, in txn] + + return self.db_pool.runInteraction( + "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn + ) + + @cached(max_entries=5000, iterable=True) + def get_latest_event_ids_in_room(self, room_id): + return self.db_pool.simple_select_onecol( + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + desc="get_latest_event_ids_in_room", + ) + + def get_min_depth(self, room_id): + """ For hte given room, get the minimum depth we have seen for it. + """ + return self.db_pool.runInteraction( + "get_min_depth", self._get_min_depth_interaction, room_id + ) + + def _get_min_depth_interaction(self, txn, room_id): + min_depth = self.db_pool.simple_select_one_onecol_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + retcol="min_depth", + allow_none=True, + ) + + return int(min_depth) if min_depth is not None else None + + def get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + + Args: + room_id (str): + stream_ordering (int): + + Returns: + deferred, which resolves to a list of event_ids + """ + # We want to make the cache more effective, so we clamp to the last + # change before the given ordering. + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + + # We don't always have a full stream_to_exterm_id table, e.g. after + # the upgrade that introduced it, so we make sure we never ask for a + # stream_ordering from before a restart + last_change = max(self._stream_order_on_start, last_change) + + # provided the last_change is recent enough, we now clamp the requested + # stream_ordering to it. 
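# Sketch of why the clamp applied just below helps the cache (invented
# numbers): if the room last changed at stream ordering 1000, then requests
# for positions 1000, 1017 and 1234 all have the same answer, so they all
# collapse onto the single cache entry keyed at 1000.
def _clamp_sketch(requested_ordering, last_change):
    return min(requested_ordering, last_change)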
+ if last_change > self.stream_ordering_month_ago: + stream_ordering = min(last_change, stream_ordering) + + return self._get_forward_extremeties_for_room(room_id, stream_ordering) + + @cached(max_entries=5000, num_args=2) + def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + """ + + if stream_ordering <= self.stream_ordering_month_ago: + raise StoreError(400, "stream_ordering too old") + + sql = """ + SELECT event_id FROM stream_ordering_to_exterm + INNER JOIN ( + SELECT room_id, MAX(stream_ordering) AS stream_ordering + FROM stream_ordering_to_exterm + WHERE stream_ordering <= ? GROUP BY room_id + ) AS rms USING (room_id, stream_ordering) + WHERE room_id = ? + """ + + def get_forward_extremeties_for_room_txn(txn): + txn.execute(sql, (stream_ordering, room_id)) + return [event_id for event_id, in txn] + + return self.db_pool.runInteraction( + "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn + ) + + def get_backfill_events(self, room_id, event_list, limit): + """Get a list of Events for a given topic that occurred before (and + including) the events in event_list. Return a list of max size `limit` + + Args: + txn + room_id (str) + event_list (list) + limit (int) + """ + return ( + self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, + ) + .addCallback(self.get_events_as_list) + .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + ) + + def _get_backfill_events(self, txn, room_id, event_list, limit): + logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) + + event_results = set() + + # We want to make sure that we do a breadth-first, "depth" ordered + # search. + + query = ( + "SELECT depth, prev_event_id FROM event_edges" + " INNER JOIN events" + " ON prev_event_id = events.event_id" + " WHERE event_edges.event_id = ?" + " AND event_edges.is_state = ?" + " LIMIT ?" + ) + + queue = PriorityQueue() + + for event_id in event_list: + depth = self.db_pool.simple_select_one_onecol_txn( + txn, + table="events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcol="depth", + allow_none=True, + ) + + if depth: + queue.put((-depth, event_id)) + + while not queue.empty() and len(event_results) < limit: + try: + _, event_id = queue.get_nowait() + except Empty: + break + + if event_id in event_results: + continue + + event_results.add(event_id) + + txn.execute(query, (event_id, False, limit - len(event_results))) + + for row in txn: + if row[1] not in event_results: + queue.put((-row[0], row[1])) + + return event_results + + @defer.inlineCallbacks + def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = yield self.db_pool.runInteraction( + "get_missing_events", + self._get_missing_events, + room_id, + earliest_events, + latest_events, + limit, + ) + events = yield self.get_events_as_list(ids) + return events + + def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): + + seen_events = set(earliest_events) + front = set(latest_events) - seen_events + event_results = [] + + query = ( + "SELECT prev_event_id FROM event_edges " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " + "LIMIT ?" 
+ ) + + while front and len(event_results) < limit: + new_front = set() + for event_id in front: + txn.execute( + query, (room_id, event_id, False, limit - len(event_results)) + ) + + new_results = {t[0] for t in txn} - seen_events + + new_front |= new_results + seen_events |= new_results + event_results.extend(new_results) + + front = new_front + + # we built the list working backwards from latest_events; we now need to + # reverse it so that the events are approximately chronological. + event_results.reverse() + return event_results + + @defer.inlineCallbacks + def get_successor_events(self, event_ids): + """Fetch all events that have the given events as a prev event + + Args: + event_ids (iterable[str]) + + Returns: + Deferred[list[str]] + """ + rows = yield self.db_pool.simple_select_many_batch( + table="event_edges", + column="prev_event_id", + iterable=event_ids, + retcols=("event_id",), + desc="get_successor_events", + ) + + return [row["event_id"] for row in rows] + + +class EventFederationStore(EventFederationWorkerStore): + """ Responsible for storing and serving up the various graphs associated + with an event. Including the main event graph and the auth chains for an + event. + + Also has methods for getting the front (latest) and back (oldest) edges + of the event graphs. These are used to generate the parents for new events + and backfilling from another server respectively. + """ + + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth + ) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = """ + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """ + txn.execute( + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + ) + + return run_as_background_process( + "delete_old_forward_extrem_cache", + self.db_pool.runInteraction, + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn, + ) + + def clean_room_for_join(self, room_id): + return self.db_pool.runInteraction( + "clean_room_for_join", self._clean_room_for_join_txn, room_id + ) + + def _clean_room_for_join_txn(self, txn, room_id): + query = "DELETE FROM event_forward_extremities WHERE room_id = ?" 
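# The retention rule applied by _delete_old_forward_extrem_cache above,
# restated on plain data (invented helper; rows are (room_id, stream_ordering)
# pairs): old rows are only dropped for rooms that also have a newer row, so
# no room ever loses its last cached set of extremities.
def _prune_exterm_cache_sketch(rows, cutoff):
    rooms_with_recent = {room for room, ordering in rows if ordering > cutoff}
    return [
        (room, ordering)
        for room, ordering in rows
        if ordering >= cutoff or room not in rooms_with_recent
    ]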
+ + txn.execute(query, (room_id,)) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + @defer.inlineCallbacks + def _background_delete_non_state_event_auth(self, progress, batch_size): + def delete_event_auth(txn): + target_min_stream_id = progress.get("target_min_stream_id_inclusive") + max_stream_id = progress.get("max_stream_id_exclusive") + + if not target_min_stream_id or not max_stream_id: + txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") + rows = txn.fetchall() + target_min_stream_id = rows[0][0] + + txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") + rows = txn.fetchall() + max_stream_id = rows[0][0] + + min_stream_id = max_stream_id - batch_size + + sql = """ + DELETE FROM event_auth + WHERE event_id IN ( + SELECT event_id FROM events + LEFT JOIN state_events USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND state_key IS null + ) + """ + + txn.execute(sql, (min_stream_id, max_stream_id)) + + new_progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.EVENT_AUTH_STATE_ONLY, new_progress + ) + + return min_stream_id >= target_min_stream_id + + result = yield self.db_pool.runInteraction( + self.EVENT_AUTH_STATE_ONLY, delete_event_auth + ) + + if not result: + yield self.db_pool.updates._end_background_update( + self.EVENT_AUTH_STATE_ONLY + ) + + return batch_size diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py new file mode 100644 index 0000000000..b8cefb4d5e --- /dev/null +++ b/synapse/storage/databases/main/event_push_actions.py @@ -0,0 +1,885 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List + +from canonicaljson import json + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.util.caches.descriptors import cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] +DEFAULT_HIGHLIGHT_ACTION = [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, +] + + +def _serialize_action(actions, is_highlight): + """Custom serializer for actions. This allows us to "compress" common actions. + + We use the fact that most users have the same actions for notifs (and for + highlights). + We store these default actions as the empty string rather than the full JSON. 
+ Since the empty string isn't valid JSON there is no risk of this clashing with + any real JSON actions + """ + if is_highlight: + if actions == DEFAULT_HIGHLIGHT_ACTION: + return "" # We use empty string as the column is non-NULL + else: + if actions == DEFAULT_NOTIF_ACTION: + return "" + return json.dumps(actions) + + +def _deserialize_action(actions, is_highlight): + """Custom deserializer for actions. This allows us to "compress" common actions + """ + if actions: + return db_to_json(actions) + + if is_highlight: + return DEFAULT_HIGHLIGHT_ACTION + else: + return DEFAULT_NOTIF_ACTION + + +class EventPushActionsWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) + + # These get correctly set by _find_stream_orderings_for_times_txn + self.stream_ordering_month_ago = None + self.stream_ordering_day_ago = None + + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() + + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 10 * 60 * 1000 + ) + self._rotate_delay = 3 + self._rotate_count = 10000 + + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) + def get_unread_event_push_actions_by_room_for_user( + self, room_id, user_id, last_read_event_id + ): + ret = yield self.db_pool.runInteraction( + "get_unread_event_push_actions_by_room", + self._get_unread_counts_by_receipt_txn, + room_id, + user_id, + last_read_event_id, + ) + return ret + + def _get_unread_counts_by_receipt_txn( + self, txn, room_id, user_id, last_read_event_id + ): + sql = ( + "SELECT stream_ordering" + " FROM events" + " WHERE room_id = ? AND event_id = ?" + ) + txn.execute(sql, (room_id, last_read_event_id)) + results = txn.fetchall() + if len(results) == 0: + return {"notify_count": 0, "highlight_count": 0} + + stream_ordering = results[0][0] + + return self._get_unread_counts_by_pos_txn( + txn, room_id, user_id, stream_ordering + ) + + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): + + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" + ) + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + notify_count = row[0] if row else 0 + + txn.execute( + """ + SELECT notif_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, + (room_id, user_id, stream_ordering), + ) + rows = txn.fetchall() + if rows: + notify_count += rows[0][0] + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" + ) + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return {"notify_count": notify_count, "highlight_count": highlight_count} + + async def get_push_action_users_in_range( + self, min_stream_ordering, max_stream_ordering + ): + def f(txn): + sql = ( + "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" + " stream_ordering >= ? AND stream_ordering <= ?" 
+ ) + txn.execute(sql, (min_stream_ordering, max_stream_ordering)) + return [r[0] for r in txn] + + ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) + return ret + + async def get_unread_push_actions_for_user_in_range_for_http( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the httppusher. + + Args: + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering: The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit: The maximum number of rows to return. + Returns: + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". + The list will be ordered by ascending stream_ordering. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the next + # push actions + def get_after_receipt(txn): + # find rooms that have a read receipt in them and return the next + # push actions + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " + " FROM (" + " SELECT room_id," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " WHERE" + " ep.room_id = rl.room_id" + " AND ep.stream_ordering > rl.stream_ordering" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + after_read_receipt = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + no_read_receipt = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + ) + + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": _deserialize_action(row[3], row[4]), + } + for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by stream_ordering, oldest first. 
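# Round-trip sketch for the action "compression" helpers defined at the top
# of this file, as used when decoding `ep.actions` above (values invented).
# The common default action is stored as the empty string and re-expanded on
# read; anything else is stored as full JSON.
_default = ["notify", {"set_tweak": "highlight", "value": False}]
_custom = ["notify", {"set_tweak": "sound", "value": "ping"}]
assert _serialize_action(_default, is_highlight=False) == ""
assert _deserialize_action("", is_highlight=False) == _default
assert _deserialize_action(_serialize_action(_custom, False), False) == _custom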
+ notifs.sort(key=lambda r: r["stream_ordering"]) + + # Take only up to the limit. We have to stop at the limit because + # one of the subqueries may have hit the limit. + return notifs[:limit] + + async def get_unread_push_actions_for_user_in_range_for_email( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the emailpusher + + Args: + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering: The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit: The maximum number of rows to return. + Returns: + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". + The list will be ordered by descending received_ts. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the most recent + # push actions + def get_after_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight, e.received_ts" + " FROM (" + " SELECT room_id," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id = rl.room_id" + " AND ep.stream_ordering > rl.stream_ordering" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + after_read_receipt = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight, e.received_ts" + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + no_read_receipt = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt + ) + + # Make a list of dicts from the two sets of results. 
+ notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": _deserialize_action(row[3], row[4]), + "received_ts": row[5], + } + for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by received_ts (most recent first) + notifs.sort(key=lambda r: -(r["received_ts"] or 0)) + + # Now return the first `limit` + return notifs[:limit] + + def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + """A fast check to see if there might be something to push for the + user since the given stream ordering. May return false positives. + + Useful to know whether to bother starting a pusher on start up or not. + + Args: + user_id (str) + min_stream_ordering (int) + + Returns: + Deferred[bool]: True if there may be push to process, False if + there definitely isn't. + """ + + def _get_if_maybe_push_in_range_for_user_txn(txn): + sql = """ + SELECT 1 FROM event_push_actions + WHERE user_id = ? AND stream_ordering > ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, min_stream_ordering)) + return bool(txn.fetchone()) + + return self.db_pool.runInteraction( + "get_if_maybe_push_in_range_for_user", + _get_if_maybe_push_in_range_for_user_txn, + ) + + async def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. + + Args: + event_id (str) + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. + + Returns: + Deferred + """ + + if not user_id_actions: + return + + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. + def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) + """ + + txn.executemany( + sql, + ( + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.items() + ), + ) + + return await self.db_pool.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn + ) + + async def remove_push_actions_from_staging(self, event_id: str) -> None: + """Called if we failed to persist the event to ensure that stale push + actions don't build up in the DB + """ + + try: + res = await self.db_pool.simple_delete( + table="event_push_actions_staging", + keyvalues={"event_id": event_id}, + desc="remove_push_actions_from_staging", + ) + return res + except Exception: + # this method is called from an exception handler, so propagating + # another exception here really isn't helpful - there's nothing + # the caller can do about it. Just log the exception and move on. 
+ logger.exception( + "Error removing push actions after event persistence failure" + ) + + def _find_stream_orderings_for_times(self): + return run_as_background_process( + "event_push_action_stream_orderings", + self.db_pool.runInteraction, + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn, + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago + ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago + ) + + def find_first_stream_ordering_after_ts(self, ts): + """Gets the stream ordering corresponding to a given timestamp. + + Specifically, finds the stream_ordering of the first event that was + received on or after the timestamp. This is done by a binary search on + the events table, since there is no index on received_ts, so is + relatively slow. + + Args: + ts (int): timestamp in millis + + Returns: + Deferred[int]: stream ordering of the first event received on/after + the timestamp + """ + return self.db_pool.runInteraction( + "_find_first_stream_ordering_after_ts_txn", + self._find_first_stream_ordering_after_ts_txn, + ts, + ) + + @staticmethod + def _find_first_stream_ordering_after_ts_txn(txn, ts): + """ + Find the stream_ordering of the first event that was received on or + after a given timestamp. This is relatively slow as there is no index + on received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + + Args: + txn (twisted.enterprise.adbapi.Transaction): + ts (int): timestamp to search for + + Returns: + int: stream ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + # We want the first stream_ordering in which received_ts is greater + # than or equal to ts. Call this point X. + # + # We maintain the invariants: + # + # range_start <= X <= range_end + # + range_start = 0 + range_end = max_stream_ordering + 1 + + # Given a stream_ordering, look up the timestamp at that + # stream_ordering. + # + # The array may be sparse (we may be missing some stream_orderings). + # We treat the gaps as the same as having the same value as the + # preceding entry, because we will pick the lowest stream_ordering + # which satisfies our requirement of received_ts >= ts. + # + # For example, if our array of events indexed by stream_ordering is + # [10, , 20], we should treat this as being equivalent to + # [10, 10, 20]. + # + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering <= ?" 
+ " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + + while range_end - range_start > 0: + middle = (range_end + range_start) // 2 + txn.execute(sql, (middle,)) + row = txn.fetchone() + if row is None: + # no rows with stream_ordering<=middle + range_start = middle + 1 + continue + + middle_ts = row[0] + if ts > middle_ts: + # we got a timestamp lower than the one we were looking for. + # definitely need to look higher: X > middle. + range_start = middle + 1 + else: + # we got a timestamp higher than (or the same as) the one we + # were looking for. We aren't yet sure about the point we + # looked up, but we can be sure that X <= middle. + range_end = middle + + return range_end + + async def get_time_of_last_push_action_before(self, stream_ordering): + def f(txn): + sql = ( + "SELECT e.received_ts" + " FROM event_push_actions AS ep" + " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" + " WHERE ep.stream_ordering > ?" + " ORDER BY ep.stream_ordering ASC" + " LIMIT 1" + ) + txn.execute(sql, (stream_ordering,)) + return txn.fetchone() + + result = await self.db_pool.runInteraction( + "get_time_of_last_push_action_before", f + ) + return result[0] if result else None + + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.db_pool.updates.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1", + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._start_rotate_notifs, 30 * 60 * 1000 + ) + + async def get_push_actions_for_user( + self, user_id, before=None, limit=50, only_highlight=False + ): + def f(txn): + before_clause = "" + if before: + before_clause = "AND epa.stream_ordering < ?" + args = [user_id, before, limit] + else: + args = [user_id, limit] + + if only_highlight: + if len(before_clause) > 0: + before_clause += " " + before_clause += "AND epa.highlight = 1" + + # NB. This assumes event_ids are globally unique since + # it makes the query easier to index + sql = ( + "SELECT epa.event_id, epa.room_id," + " epa.stream_ordering, epa.topological_ordering," + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" + " FROM event_push_actions epa, events e" + " WHERE epa.event_id = e.event_id" + " AND epa.user_id = ? %s" + " ORDER BY epa.stream_ordering DESC" + " LIMIT ?" 
% (before_clause,) + ) + txn.execute(sql, args) + return self.db_pool.cursor_to_dict(txn) + + push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) + for pa in push_actions: + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) + return push_actions + + async def get_latest_push_action_stream_ordering(self): + def f(txn): + txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") + return txn.fetchone() + + result = await self.db_pool.runInteraction( + "get_latest_push_action_stream_ordering", f + ) + return result[0] or 0 + + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + + def _start_rotate_notifs(self): + return run_as_background_process("rotate_notifs", self._rotate_notifs) + + async def _rotate_notifs(self): + if self._doing_notif_rotation or self.stream_ordering_day_ago is None: + return + self._doing_notif_rotation = True + + try: + while True: + logger.info("Rotating notifications") + + caught_up = await self.db_pool.runInteraction( + "_rotate_notifs", self._rotate_notifs_txn + ) + if caught_up: + break + await self.hs.get_clock().sleep(self._rotate_delay) + finally: + self._doing_notif_rotation = False + + def _rotate_notifs_txn(self, txn): + """Archives older notifications into event_push_summary. Returns whether + the archiving process has caught up or not. + """ + + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # We don't to try and rotate millions of rows at once, so we cap the + # maximum stream ordering we'll rotate before. + txn.execute( + """ + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering > ? + ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? 
+ """, + (old_rotate_stream_ordering, self._rotate_count), + ) + stream_row = txn.fetchone() + if stream_row: + (offset_stream_ordering,) = stream_row + rotate_to_stream_ordering = min( + self.stream_ordering_day_ago, offset_stream_ordering + ) + caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + else: + rotate_to_stream_ordering = self.stream_ordering_day_ago + caught_up = True + + logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) + + self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + + # We have caught up iff we were limited by `stream_ordering_day_ago` + return caught_up + + def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # Calculate the new counts that should be upserted into event_push_summary + sql = """ + SELECT user_id, room_id, + coalesce(old.notif_count, 0) + upd.notif_count, + upd.stream_ordering, + old.user_id + FROM ( + SELECT user_id, room_id, count(*) as notif_count, + max(stream_ordering) as stream_ordering + FROM event_push_actions + WHERE ? <= stream_ordering AND stream_ordering < ? + AND highlight = 0 + GROUP BY user_id, room_id + ) AS upd + LEFT JOIN event_push_summary AS old USING (user_id, room_id) + """ + + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) + rows = txn.fetchall() + + logger.info("Rotating notifications, handling %d rows", len(rows)) + + # If the `old.user_id` above is NULL then we know there isn't already an + # entry in the table, so we simply insert it. Otherwise we update the + # existing table. + self.db_pool.simple_insert_many_txn( + txn, + table="event_push_summary", + values=[ + { + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], + } + for row in rows + if row[4] is None + ], + ) + + txn.executemany( + """ + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + WHERE user_id = ? AND room_id = ? + """, + ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), + ) + + txn.execute( + "DELETE FROM event_push_actions" + " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) + + txn.execute( + "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", + (rotate_to_stream_ordering,), + ) + + +def _action_has_highlight(actions): + for action in actions: + try: + if action.get("set_tweak", None) == "highlight": + return action.get("value", True) + except AttributeError: + pass + + return False diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py new file mode 100644 index 0000000000..4d8a24ce4b --- /dev/null +++ b/synapse/storage/databases/main/events.py @@ -0,0 +1,1527 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import logging +from collections import OrderedDict, namedtuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple + +import attr +from prometheus_client import Counter + +from twisted.internet import defer + +import synapse.metrics +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.room_versions import RoomVersions +from synapse.crypto.event_signing import compute_event_reference_hash +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.logging.utils import log_function +from synapse.storage._base import db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.search import SearchEntry +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import StateMap, get_domain_from_id +from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.iterutils import batch_iter + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + + +logger = logging.getLogger(__name__) + +persist_event_counter = Counter("synapse_storage_events_persisted_events", "") +event_counter = Counter( + "synapse_storage_events_persisted_events_sep", + "", + ["type", "origin_type", "origin_entity"], +) + +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + + +def encode_json(json_object): + """ + Encode a Python object as JSON and return it in a Unicode string. + """ + out = frozendict_json_encoder.encode(json_object) + if isinstance(out, bytes): + out = out.decode("utf8") + return out + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +@attr.s(slots=True) +class DeltaState: + """Deltas to use to update the `current_state_events` table. 
+ + Attributes: + to_delete: List of type/state_keys to delete from current state + to_insert: Map of state to upsert into current state + no_longer_in_room: The server is not longer in the room, so the room + should e.g. be removed from `current_state_events` table. + """ + + to_delete = attr.ib(type=List[Tuple[str, str]]) + to_insert = attr.ib(type=StateMap[str]) + no_longer_in_room = attr.ib(type=bool, default=False) + + +class PersistEventsStore: + """Contains all the functions for writing events to the database. + + Should only be instantiated on one process (when using a worker mode setup). + + Note: This is not part of the `DataStore` mixin. + """ + + def __init__( + self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" + ): + self.hs = hs + self.db_pool = db + self.store = main_data_store + self.database_engine = db.engine + self._clock = hs.get_clock() + + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self.is_mine_id = hs.is_mine_id + + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator + + # This should only exist on instances that are configured to write + assert ( + hs.config.worker.writers.events == hs.get_instance_name() + ), "Can only instantiate EventsStore on master" + + @defer.inlineCallbacks + def _persist_events_and_state_updates( + self, + events_and_contexts: List[Tuple[EventBase, EventContext]], + current_state_for_room: Dict[str, StateMap[str]], + state_delta_for_room: Dict[str, DeltaState], + new_forward_extremeties: Dict[str, List[str]], + backfilled: bool = False, + ): + """Persist a set of events alongside updates to the current state and + forward extremities tables. + + Args: + events_and_contexts: + current_state_for_room: Map from room_id to the current state of + the room based on forward extremities + state_delta_for_room: Map from room_id to the delta to apply to + room state + new_forward_extremities: Map from room_id to list of event IDs + that are the new forward extremities of the room. + backfilled + + Returns: + Deferred: resolves when the events have been persisted + """ + + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. 
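# Illustrative sketch only, not the code added by this patch: the block below
# assigns stream orderings via `get_next_mult` as late as possible. A minimal,
# hypothetical allocator with the same shape might look like the following
# (Synapse's real StreamIdGenerator additionally tracks in-flight IDs so that
# readers never observe gaps; that bookkeeping is omitted here).
import contextlib
import threading

class ToyStreamIdGenerator:
    def __init__(self, start=0, step=1):
        self._lock = threading.Lock()
        self._current = start
        self._step = step

    @contextlib.contextmanager
    def get_next_mult(self, n):
        # Reserve n consecutive orderings under the lock, then hand them out.
        with self._lock:
            first = self._current + self._step
            self._current += self._step * n
        yield [first + i * self._step for i in range(n)]

# A step of -1 mirrors the backfill generator, which counts downwards.
gen = ToyStreamIdGenerator()
with gen.get_next_mult(3) as orderings:
    print(orderings)  # [1, 2, 3]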
+ if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) + ) + + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream + + yield self.db_pool.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + ) + persist_event_counter.inc(len(events_and_contexts)) + + if not backfilled: + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + events_and_contexts[-1][0].internal_metadata.stream_ordering + ) + + for event, context in events_and_contexts: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) + + event_counter.labels(event.type, origin_type, origin_entity).inc() + + self.store.get_unread_message_count_for_user.invalidate_many( + (event.room_id,), + ) + + for room_id, new_state in current_state_for_room.items(): + self.store.get_current_state_ids.prefill((room_id,), new_state) + + for room_id, latest_event_ids in new_forward_extremeties.items(): + self.store.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) + + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events_which_are_prevs_txn(txn, batch): + sql = """ + SELECT prev_event_id, internal_metadata + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + NOT events.outlier + AND rejections.event_id IS NULL + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", batch + ) + + txn.execute(sql + clause, args) + results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) + + for chunk in batch_iter(event_ids, 100): + yield self.db_pool.runInteraction( + "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk + ) + + return results + + @defer.inlineCallbacks + def _get_prevs_before_rejected(self, event_ids): + """Get soft-failed ancestors to remove from the extremities. + + Given a set of events, find all those that have been soft-failed or + rejected. Returns those soft failed/rejected events and their prev + events (whether soft-failed/rejected or not), and recurses up the + prev-event graph until it finds no more soft-failed/rejected events. + + This is used to find extremities that are ancestors of new events, but + are separated by soft failed events. + + Args: + event_ids (Iterable[str]): Events to find prev events for. Note + that these must have already been persisted. + + Returns: + Deferred[set[str]] + """ + + # The set of event_ids to return. 
This includes all soft-failed events + # and their prev events. + existing_prevs = set() + + def _get_prevs_before_rejected_txn(txn, batch): + to_recursively_check = batch + + while to_recursively_check: + sql = """ + SELECT + event_id, prev_event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + NOT events.outlier + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", to_recursively_check + ) + + txn.execute(sql + clause, args) + to_recursively_check = [] + + for event_id, prev_event_id, metadata, rejected in txn: + if prev_event_id in existing_prevs: + continue + + soft_failed = db_to_json(metadata).get("soft_failed") + if soft_failed or rejected: + to_recursively_check.append(prev_event_id) + existing_prevs.add(prev_event_id) + + for chunk in batch_iter(event_ids, 100): + yield self.db_pool.runInteraction( + "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk + ) + + return existing_prevs + + @log_function + def _persist_events_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + state_delta_for_room: Dict[str, DeltaState] = {}, + new_forward_extremeties: Dict[str, List[str]] = {}, + ): + """Insert some number of room events into the necessary database tables. + + Rejected events are only inserted into the events table, the events_json table, + and the rejections table. Things reading from those table will need to check + whether the event was rejected. + + Args: + txn + events_and_contexts: events to persist + backfilled: True if the events were backfilled + delete_existing True to purge existing table rows for the events + from the database. This is useful when retrying due to + IntegrityError. + state_delta_for_room: The current-state delta for each room. + new_forward_extremetie: The new forward extremities for each room. + For each room, a list of the event ids which are the forward + extremities. + + """ + all_events_and_contexts = events_and_contexts + + min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremeties, + max_stream_order=max_stream_order, + ) + + # Ensure that we don't have the same event twice. + events_and_contexts = self._filter_events_and_contexts_for_duplicates( + events_and_contexts + ) + + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) + + # _update_outliers_txn filters out any events which have already been + # persisted, and returns the filtered list. + events_and_contexts = self._update_outliers_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only events that we haven't + # seen before. + + self._store_event_txn(txn, events_and_contexts=events_and_contexts) + + # Insert into event_to_state_groups. + self._store_event_state_mappings_txn(txn, events_and_contexts) + + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). 
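# Not part of the patch: an in-memory restatement of the recursion in
# _get_prevs_before_rejected above. `get_prevs` and `is_bad` are hypothetical
# callables standing in for the event_edges / rejections lookups done in SQL.
def prevs_before_rejected(event_ids, get_prevs, is_bad):
    existing_prevs = set()
    to_check = set(event_ids)
    while to_check:
        next_round = set()
        for event_id in to_check:
            if not is_bad(event_id):
                # Neither soft-failed nor rejected: stop recursing here.
                continue
            for prev_id in get_prevs(event_id):
                if prev_id not in existing_prevs:
                    existing_prevs.add(prev_id)
                    next_round.add(prev_id)
        to_check = next_round
    return existing_prevs

graph = {"$c": ["$b"], "$b": ["$a"], "$a": []}
bad = {"$c", "$b"}
print(prevs_before_rejected(["$c"], lambda e: graph.get(e, []), bad.__contains__))
# prints a set containing '$b' and '$a': ancestors hidden behind soft-failed events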
+ self.db_pool.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id in event.auth_event_ids() + if event.is_state() + ], + ) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _update_current_state_txn( + self, + txn: LoggingTransaction, + state_delta_by_room: Dict[str, DeltaState], + stream_id: int, + ): + for room_id, delta_state in state_delta_by_room.items(): + to_delete = delta_state.to_delete + to_insert = delta_state.to_insert + + if delta_state.no_longer_in_room: + # Server is no longer in the room so we delete the room from + # current_state_events, being careful we've already updated the + # rooms.room_version column (which gets populated in a + # background task). + self._upsert_room_version_txn(txn, room_id) + + # Before deleting we populate the current_state_delta_stream + # so that async background tasks get told what happened. + sql = """ + INSERT INTO current_state_delta_stream + (stream_id, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, room_id, type, state_key, null, event_id + FROM current_state_events + WHERE room_id = ? + """ + txn.execute(sql, (stream_id, room_id)) + + self.db_pool.simple_delete_txn( + txn, table="current_state_events", keyvalues={"room_id": room_id}, + ) + else: + # We're still in the room, so we update the current state as normal. + + # First we add entries to the current_state_delta_stream. We + # do this before updating the current_state_events table so + # that we can use it to calculate the `prev_event_id`. (This + # allows us to not have to pull out the existing state + # unnecessarily). + # + # The stream_id for the update is chosen to be the minimum of the stream_ids + # for the batch of the events that we are persisting; that means we do not + # end up in a situation where workers see events before the + # current_state_delta updates. + # + sql = """ + INSERT INTO current_state_delta_stream + (stream_id, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ( + SELECT event_id FROM current_state_events + WHERE room_id = ? AND type = ? AND state_key = ? + ) + """ + txn.executemany( + sql, + ( + ( + stream_id, + room_id, + etype, + state_key, + to_insert.get((etype, state_key)), + room_id, + etype, + state_key, + ) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) + # Now we actually update the current_state_events table + + txn.executemany( + "DELETE FROM current_state_events" + " WHERE room_id = ? AND type = ? AND state_key = ?", + ( + (room_id, etype, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) + + # We include the membership in the current state table, hence we do + # a lookup when we insert. This assumes that all events have already + # been inserted into room_memberships. 
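# Illustration only, not Synapse code: the current-state handling begun above
# is, at its core, applying a DeltaState to the room's current-state map of
# (event type, state key) -> event_id. The toy state below is invented.
def apply_state_delta(current_state, to_delete, to_insert):
    for key in to_delete:
        current_state.pop(key, None)
    current_state.update(to_insert)
    return current_state

state = {("m.room.member", "@alice:example.com"): "$join_event"}
apply_state_delta(
    state,
    to_delete=[("m.room.member", "@alice:example.com")],
    to_insert={("m.room.name", ""): "$name_event"},
)
print(state)  # {('m.room.name', ''): '$name_event'}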
+ txn.executemany( + """INSERT INTO current_state_events + (room_id, type, state_key, event_id, membership) + VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[0], key[1], ev_id, ev_id) + for key, ev_id in to_insert.items() + ], + ) + + # We now update `local_current_membership`. We do this regardless + # of whether we're still in the room or not to handle the case where + # e.g. we just got banned (where we need to record that fact here). + + # Note: Do we really want to delete rows here (that we do not + # subsequently reinsert below)? While technically correct it means + # we have no record of the fact the user *was* a member of the + # room but got, say, state reset out of it. + if to_delete or to_insert: + txn.executemany( + "DELETE FROM local_current_membership" + " WHERE room_id = ? AND user_id = ?", + ( + (room_id, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + if etype == EventTypes.Member and self.is_mine_id(state_key) + ), + ) + + if to_insert: + txn.executemany( + """INSERT INTO local_current_membership + (room_id, user_id, event_id, membership) + VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[1], ev_id, ev_id) + for key, ev_id in to_insert.items() + if key[0] == EventTypes.Member and self.is_mine_id(key[1]) + ], + ) + + txn.call_after( + self.store._curr_state_delta_stream_cache.entity_has_changed, + room_id, + stream_id, + ) + + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } + + for member in members_changed: + txn.call_after( + self.store.get_rooms_for_user_with_stream_ordering.invalidate, + (member,), + ) + + self.store._invalidate_state_caches_and_stream( + txn, room_id, members_changed + ) + + def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): + """Update the room version in the database based off current state + events. + + This is used when we're about to delete current state and we want to + ensure that the `rooms.room_version` column is up to date. + """ + + sql = """ + SELECT json FROM event_json + INNER JOIN current_state_events USING (room_id, event_id) + WHERE room_id = ? AND type = ? AND state_key = ? 
+ """ + txn.execute(sql, (room_id, EventTypes.Create, "")) + row = txn.fetchone() + if row: + event_json = db_to_json(row[0]) + content = event_json.get("content", {}) + creator = content.get("creator") + room_version_id = content.get("room_version", RoomVersions.V1.identifier) + + self.db_pool.simple_upsert_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version_id}, + insertion_values={"is_public": False, "creator": creator}, + ) + + def _update_forward_extremities_txn( + self, txn, new_forward_extremities, max_stream_order + ): + for room_id, new_extrem in new_forward_extremities.items(): + self.db_pool.simple_delete_txn( + txn, table="event_forward_extremities", keyvalues={"room_id": room_id} + ) + txn.call_after( + self.store.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="event_forward_extremities", + values=[ + {"event_id": ev_id, "room_id": room_id} + for room_id, new_extrem in new_forward_extremities.items() + for ev_id in new_extrem + ], + ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + self.db_pool.simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_order, + } + for room_id, new_extrem in new_forward_extremities.items() + for event_id in new_extrem + ], + ) + + @classmethod + def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + """Ensure that we don't have the same event twice. + + Pick the earliest non-outlier if there is one, else the earliest one. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + Returns: + list[(EventBase, EventContext)]: filtered list + """ + new_events_and_contexts = OrderedDict() + for event, context in events_and_contexts: + prev_event_context = new_events_and_contexts.get(event.event_id) + if prev_event_context: + if not event.internal_metadata.is_outlier(): + if prev_event_context[0].internal_metadata.is_outlier(): + # To ensure correct ordering we pop, as OrderedDict is + # ordered by first insertion. 
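# Standalone sketch of the de-duplication rule documented in
# _filter_events_and_contexts_for_duplicates above: keep the earliest
# non-outlier copy of each event, otherwise the earliest copy. The
# (event_id, is_outlier) tuples stand in for (EventBase, EventContext) pairs.
from collections import OrderedDict

def dedupe(events):
    seen = OrderedDict()
    for event_id, is_outlier in events:
        prev = seen.get(event_id)
        if prev is None:
            seen[event_id] = (event_id, is_outlier)
        elif not is_outlier and prev[1]:
            # Replace an outlier copy with the non-outlier one; pop first so
            # the entry's position reflects the copy we actually keep.
            seen.pop(event_id)
            seen[event_id] = (event_id, is_outlier)
    return list(seen.values())

print(dedupe([("$a", True), ("$b", False), ("$a", False)]))
# [('$b', False), ('$a', False)]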
+ new_events_and_contexts.pop(event.event_id, None) + new_events_and_contexts[event.event_id] = (event, context) + else: + new_events_and_contexts[event.event_id] = (event, context) + return list(new_events_and_contexts.values()) + + def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + """Update min_depth for each room + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ + depth_updates = {} + for event, context in events_and_contexts: + # Remove the any existing cache entries for the event_ids + txn.call_after(self.store._invalidate_get_event_cache, event.event_id) + if not backfilled: + txn.call_after( + self.store._events_stream_cache.entity_has_changed, + event.room_id, + event.internal_metadata.stream_ordering, + ) + + if not event.internal_metadata.is_outlier() and not context.rejected: + depth_updates[event.room_id] = max( + event.depth, depth_updates.get(event.room_id, event.depth) + ) + + for room_id, depth in depth_updates.items(): + self._update_min_depth_for_room_txn(txn, room_id, depth) + + def _update_outliers_txn(self, txn, events_and_contexts): + """Update any outliers with new event info. + + This turns outliers into ex-outliers (unless the new event was + rejected). + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without events which + are already in the events table. + """ + txn.execute( + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" + % (",".join(["?"] * len(events_and_contexts)),), + [event.event_id for event, _ in events_and_contexts], + ) + + have_persisted = {event_id: outlier for event_id, outlier in txn} + + to_remove = set() + for event, context in events_and_contexts: + if event.event_id not in have_persisted: + continue + + to_remove.add(event) + + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + continue + + outlier_persisted = have_persisted[event.event_id] + if not event.internal_metadata.is_outlier() and outlier_persisted: + # We received a copy of an event that we had already stored as + # an outlier in the database. We now have some state at that + # so we need to update the state_groups table with that state. + + # insert into event_to_state_groups. + try: + self._store_event_state_mappings_txn(txn, ((event, context),)) + except Exception: + logger.exception("") + raise + + metadata_json = encode_json(event.internal_metadata.get_dict()) + + sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" + txn.execute(sql, (metadata_json, event.event_id)) + + # Add an entry to the ex_outlier_stream table to replicate the + # change in outlier status to our workers. + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group + self.db_pool.simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + }, + ) + + sql = "UPDATE events SET outlier = ? WHERE event_id = ?" + txn.execute(sql, (False, event.event_id)) + + # Update the event_backward_extremities table now that this + # event isn't an outlier any more. 
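# A compact, in-memory restatement (not the storage code itself) of the depth
# bookkeeping in _update_room_depths_txn above: each batch proposes at most
# one candidate depth per room, and a room's stored min_depth only ever moves
# downwards. The tuple shape here is invented for the example.
def compute_depth_updates(events):
    # events: iterable of (room_id, depth, is_outlier, rejected)
    depth_updates = {}
    for room_id, depth, is_outlier, rejected in events:
        if is_outlier or rejected:
            continue
        depth_updates[room_id] = max(depth, depth_updates.get(room_id, depth))
    return depth_updates

def apply_min_depth(current, updates):
    for room_id, depth in updates.items():
        existing = current.get(room_id)
        if existing is None or depth < existing:
            current[room_id] = depth
    return current

print(apply_min_depth({"!room:hs": 5}, compute_depth_updates(
    [("!room:hs", 7, False, False), ("!room:hs", 9, False, False)]
)))
# {'!room:hs': 5} -- the candidate depth 9 never raises the stored min_depth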
+ self._update_backward_extremeties(txn, [event]) + + return [ec for ec in events_and_contexts if ec[0] not in to_remove] + + def _store_event_txn(self, txn, events_and_contexts): + """Insert new events into the event and event_json tables + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + """ + + if not events_and_contexts: + # nothing to do here + return + + def event_dict(event): + d = event.get_dict() + d.pop("redacted", None) + d.pop("redacted_because", None) + return d + + self.db_pool.simple_insert_many_txn( + txn, + table="event_json", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "internal_metadata": encode_json( + event.internal_metadata.get_dict() + ), + "json": encode_json(event_dict(event)), + "format_version": event.format_version, + } + for event, _ in events_and_contexts + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="events", + values=[ + { + "stream_ordering": event.internal_metadata.stream_ordering, + "topological_ordering": event.depth, + "depth": event.depth, + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "processed": True, + "outlier": event.internal_metadata.is_outlier(), + "origin_server_ts": int(event.origin_server_ts), + "received_ts": self._clock.time_msec(), + "sender": event.sender, + "contains_url": ( + "url" in event.content and isinstance(event.content["url"], str) + ), + "count_as_unread": should_count_as_unread(event, context), + } + for event, context in events_and_contexts + ], + ) + + for event, _ in events_and_contexts: + if not event.internal_metadata.is_redacted(): + # If we're persisting an unredacted event we go and ensure + # that we mark any redactions that reference this event as + # requiring censoring. + self.db_pool.simple_update_txn( + txn, + table="redactions", + keyvalues={"redacts": event.event_id}, + updatevalues={"have_censored": False}, + ) + + def _store_rejected_events_txn(self, txn, events_and_contexts): + """Add rows to the 'rejections' table for received events which were + rejected + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without the rejected + events. + """ + # Remove the rejected events from the list now that we've added them + # to the events table and the events_json table. + to_remove = set() + for event, context in events_and_contexts: + if context.rejected: + # Insert the event_id into the rejections table + self._store_rejections_txn(txn, event.event_id, context.rejected) + to_remove.add(event) + + return [ec for ec in events_and_contexts if ec[0] not in to_remove] + + def _update_metadata_tables_txn( + self, txn, events_and_contexts, all_events_and_contexts, backfilled + ): + """Update all the miscellaneous tables for new events + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + backfilled (bool): True if the events were backfilled + """ + + # Insert all the push actions into the event_push_actions table. 
+ self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + + if not events_and_contexts: + # nothing to do here + return + + for event, context in events_and_contexts: + if event.type == EventTypes.Redaction and event.redacts is not None: + # Remove the entries in the event_push_actions table for the + # redacted event. + self._remove_push_actions_for_event_id_txn( + txn, event.room_id, event.redacts + ) + + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + + # Update the event_forward_extremities, event_backward_extremities and + # event_edges tables. + self._handle_mult_prev_events( + txn, events=[event for event, _ in events_and_contexts] + ) + + for event, _ in events_and_contexts: + if event.type == EventTypes.Name: + # Insert into the event_search table. + self._store_room_name_txn(txn, event) + elif event.type == EventTypes.Topic: + # Insert into the event_search table. + self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Message: + # Insert into the event_search table. + self._store_room_message_txn(txn, event) + elif event.type == EventTypes.Redaction and event.redacts is not None: + # Insert into the redactions table. + self._store_redaction(txn, event) + elif event.type == EventTypes.Retention: + # Update the room_retention table. + self._store_retention_policy_for_room_txn(txn, event) + + self._handle_event_relations(txn, event) + + # Store the labels for this event. + labels = event.content.get(EventContentFields.LABELS) + if labels: + self.insert_labels_for_event_txn( + txn, event.event_id, labels, event.room_id, event.depth + ) + + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + + # Insert into the room_memberships table. + self._store_room_members_txn( + txn, + [ + event + for event, _ in events_and_contexts + if event.type == EventTypes.Member + ], + backfilled=backfilled, + ) + + # Insert event_reference_hashes table. + self._store_event_reference_hashes_txn( + txn, [event for event, _ in events_and_contexts] + ) + + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] + + state_values = [] + for event, context in state_events_and_contexts: + vals = { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + + # TODO: How does this work with backfilling? 
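# Illustration only: the label handling called from the metadata-table update
# above boils down to pulling the MSC2326 label list out of the event content;
# the _event_store_labels background update later in this patch additionally
# keeps only string values. The literal field name below is assumed to match
# EventContentFields.LABELS.
def extract_labels(content):
    labels = content.get("org.matrix.labels", [])
    if not isinstance(labels, list):
        return []
    return [label for label in labels if isinstance(label, str)]

print(extract_labels({"org.matrix.labels": ["#work", 42, "#fun"]}))
# ['#work', '#fun']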
+ if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state + + state_values.append(vals) + + self.db_pool.simple_insert_many_txn( + txn, table="state_events", values=state_values + ) + + # Prefill the event cache + self._add_to_cache(txn, events_and_contexts) + + def _add_to_cache(self, txn, events_and_contexts): + to_prefill = [] + + rows = [] + N = 200 + for i in range(0, len(events_and_contexts), N): + ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} + if not ev_map: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM events as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE " + ) + + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", list(ev_map) + ) + + txn.execute(sql + clause, args) + rows = self.db_pool.cursor_to_dict(txn) + for row in rows: + event = ev_map[row["event_id"]] + if not row["rejects"] and not row["redacts"]: + to_prefill.append( + _EventCacheEntry(event=event, redacted_event=None) + ) + + def prefill(): + for cache_entry in to_prefill: + self.store._get_event_cache.prefill( + (cache_entry[0].event_id,), cache_entry + ) + + txn.call_after(prefill) + + def _store_redaction(self, txn, event): + # invalidate the cache for the redacted event + txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + + self.db_pool.simple_insert_txn( + txn, + table="redactions", + values={ + "event_id": event.event_id, + "redacts": event.redacts, + "received_ts": self._clock.time_msec(), + }, + ) + + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. + + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. + """ + return self.db_pool.simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], + ) + + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self.db_pool.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + def _store_event_reference_hashes_txn(self, txn, events): + """Store a hash for a PDU + Args: + txn (cursor): + events (list): list of Events. + """ + + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append( + { + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": memoryview(ref_hash_bytes), + } + ) + + self.db_pool.simple_insert_many_txn( + txn, table="event_reference_hashes", values=vals + ) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. 
+ """ + self.db_pool.simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], + ) + + for event in events: + txn.call_after( + self.store._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, + ) + txn.call_after( + self.store.get_invited_rooms_for_local_user.invalidate, + (event.state_key,), + ) + + # We update the local_current_membership table only if the event is + # "current", i.e., its something that has just happened. + # + # This will usually get updated by the `current_state_events` handling, + # unless its an outlier, and an outlier is only "current" if it's an "out of + # band membership", like a remote invite or a rejection of a remote invite. + if ( + self.is_mine_id(event.state_key) + and not backfilled + and event.internal_metadata.is_outlier() + and event.internal_metadata.is_out_of_band_membership() + ): + self.db_pool.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": event.room_id, "user_id": event.state_key}, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) + + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self.db_pool.simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. 
+ """ + + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + + def _store_room_topic_txn(self, txn, event): + if hasattr(event, "content") and "topic" in event.content: + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) + + def _store_room_name_txn(self, txn, event): + if hasattr(event, "content") and "name" in event.content: + self.store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) + + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self.store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) + + def _store_retention_policy_for_room_txn(self, txn, event): + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content + ): + if ( + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), int) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), int) + ): + # Ignore the event if one of the value isn't an integer. + return + + self.db_pool.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) + + self.store._invalidate_cache_and_stream( + txn, self.store.get_retention_policy_for_room, (event.room_id,) + ) + + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store.store_search_entries_txn( + txn, + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), + ) + + def _set_push_actions_for_event_and_users_txn( + self, txn, events_and_contexts, all_events_and_contexts + ): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany( + sql, + ( + ( + event.room_id, + event.internal_metadata.stream_ordering, + event.depth, + event.event_id, + ) + for event, _ in events_and_contexts + ), + ) + + for event, _ in events_and_contexts: + user_ids = self.db_pool.simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid), + ) + + # Now we delete the staging area for *all* events that were being + # persisted. 
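# A self-contained, stdlib-only illustration (not Synapse code) of the staging
# pattern above: push actions are written to a staging table first, moved into
# the main table with an INSERT ... SELECT once the event is persisted, and
# finally removed from staging. The schema here is heavily simplified.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE staging (event_id TEXT, user_id TEXT, actions TEXT);
    CREATE TABLE main (room_id TEXT, event_id TEXT, user_id TEXT,
                       actions TEXT, stream_ordering INTEGER);
    INSERT INTO staging VALUES ('$ev', '@user:hs', '["notify"]');
    """
)
conn.execute(
    """
    INSERT INTO main (room_id, event_id, user_id, actions, stream_ordering)
    SELECT ?, event_id, user_id, actions, ? FROM staging WHERE event_id = ?
    """,
    ("!room:hs", 42, "$ev"),
)
conn.execute("DELETE FROM staging WHERE event_id = ?", ("$ev",))
print(conn.execute("SELECT * FROM main").fetchall())
# [('!room:hs', '$ev', '@user:hs', '["notify"]', 42)]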
+ txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ((event.event_id,) for event, _ in all_events_and_contexts), + ) + + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + # Sad that we have to blow away the cache for the whole room here + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id,), + ) + txn.execute( + "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", + (room_id, event_id), + ) + + def _store_rejections_txn(self, txn, event_id, reason): + self.db_pool.simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + }, + ) + + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.state_group_before_event + continue + + state_groups[event.event_id] = context.state_group + + self.db_pool.simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in state_groups.items() + ], + ) + + for event_id, state_group_id in state_groups.items(): + txn.call_after( + self.store._get_state_group_for_event.prefill, + (event_id,), + state_group_id, + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self.store._get_min_depth_interaction(txn, room_id) + + if min_depth is not None and depth >= min_depth: + return + + self.db_pool.simple_upsert_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + values={"min_depth": depth}, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self.db_pool.simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id in ev.prev_event_ids() + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) + + txn.executemany( + query, + [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + ], + ) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" 
+ ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) + for ev in events + if not ev.internal_metadata.is_outlier() + ], + ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py new file mode 100644 index 0000000000..35a0e09e3c --- /dev/null +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -0,0 +1,585 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventContentFields +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool + +logger = logging.getLogger(__name__) + + +class EventsBackgroundUpdatesStore(SQLBaseStore): + + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + self.db_pool.updates.register_background_update_handler( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + self._background_reindex_fields_sender, + ) + + self.db_pool.updates.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + + # an event_id index on event_search is useful for the purge_history + # api. 
Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.db_pool.updates.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + + self.db_pool.updates.register_background_update_handler( + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update + ) + + self.db_pool.updates.register_background_update_handler( + "redactions_received_ts", self._redactions_received_ts + ) + + # This index gets deleted in `event_fix_redactions_bytes` update + self.db_pool.updates.register_background_index_update( + "event_fix_redactions_bytes_create_index", + index_name="redactions_censored_redacts", + table="redactions", + columns=["redacts"], + where_clause="have_censored", + ) + + self.db_pool.updates.register_background_update_handler( + "event_fix_redactions_bytes", self._event_fix_redactions_bytes + ) + + self.db_pool.updates.register_background_update_handler( + "event_store_labels", self._event_store_labels + ) + + self.db_pool.updates.register_background_index_update( + "redactions_have_censored_ts_idx", + index_name="redactions_have_censored_ts", + table="redactions", + columns=["received_ts"], + where_clause="NOT have_censored", + ) + + @defer.inlineCallbacks + def _background_reindex_fields_sender(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, json FROM events" + " INNER JOIN event_json USING (event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + + update_rows = [] + for row in rows: + try: + event_id = row[1] + event_json = db_to_json(row[2]) + sender = event_json["sender"] + content = event_json["content"] + + contains_url = "url" in content + if contains_url: + contains_url &= isinstance(content["url"], str) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + update_rows.append((sender, contains_url, event_id)) + + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" 
+ + for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): + clump = update_rows[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.db_pool.runInteraction( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + ) + + if not result: + yield self.db_pool.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) + + return result + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + rows_to_update = [] + + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] + for chunk in chunks: + ev_rows = self.db_pool.simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) + + for row in ev_rows: + event_id = row["event_id"] + event_json = db_to_json(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) + + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows_to_update), + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows_to_update) + + result = yield self.db_pool.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self.db_pool.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) + + return result + + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. 
the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected. + # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their successor events, if any, and the successor events' + # rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, + (batch_size,), + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = db_to_json(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + NOT events.outlier + AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", to_check + ) + txn.execute(sql + clause, list(args)) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph[event_id].add(prev_event_id) + continue + + graph[event_id] = {prev_event_id} + + soft_failed = db_to_json(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. 
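# In-memory sketch (not the background update itself) of the final step
# described above: starting from the descendants that are neither soft-failed
# nor rejected, walk back up `graph` (event_id -> prev event ids) and delete
# any of the original extremities reachable that way.
def extremities_to_delete(original_set, graph, non_rejected_leaves):
    to_delete = set()
    leaves = set(non_rejected_leaves)
    while leaves:
        event_id = leaves.pop()
        prev_event_ids = graph.get(event_id, set())
        leaves.update(prev_event_ids)
        to_delete.update(prev_event_ids)
    return to_delete & original_set

graph = {"$child": {"$old_extremity"}}
print(extremities_to_delete({"$old_extremity"}, graph, {"$child"}))
# {'$old_extremity'} -- the extremity has a live descendant, so drop it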
+ to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + deleted = self.db_pool.simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) + + if deleted: + # We now need to invalidate the caches of these rooms + rows = self.db_pool.simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",), + ) + room_ids = {row["room_id"] for row in rows} + for room_id in room_ids: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self.db_pool.simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.db_pool.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn + ) + + if not num_handled: + yield self.db_pool.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.db_pool.runInteraction( + "_cleanup_extremities_bg_update_drop_table", _drop_table_txn + ) + + return num_handled + + @defer.inlineCallbacks + def _redactions_received_ts(self, progress, batch_size): + """Handles filling out the `received_ts` column in redactions. + """ + last_event_id = progress.get("last_event_id", "") + + def _redactions_received_ts_txn(txn): + # Fetch the set of event IDs that we want to update + sql = """ + SELECT event_id FROM redactions + WHERE event_id > ? + ORDER BY event_id ASC + LIMIT ? + """ + + txn.execute(sql, (last_event_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + (upper_event_id,) = rows[-1] + + # Update the redactions with the received_ts. + # + # Note: Not all events have an associated received_ts, so we + # fallback to using origin_server_ts. If we for some reason don't + # have an origin_server_ts, lets just use the current timestamp. + # + # We don't want to leave it null, as then we'll never try and + # censor those redactions. + sql = """ + UPDATE redactions + SET received_ts = ( + SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events + WHERE events.event_id = redactions.event_id + ) + WHERE ? <= event_id AND event_id <= ? + """ + + txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) + + self.db_pool.updates._background_update_progress_txn( + txn, "redactions_received_ts", {"last_event_id": upper_event_id} + ) + + return len(rows) + + count = yield self.db_pool.runInteraction( + "_redactions_received_ts", _redactions_received_ts_txn + ) + + if not count: + yield self.db_pool.updates._end_background_update("redactions_received_ts") + + return count + + @defer.inlineCallbacks + def _event_fix_redactions_bytes(self, progress, batch_size): + """Undoes hex encoded censored redacted event JSON. + """ + + def _event_fix_redactions_bytes_txn(txn): + # This update is quite fast due to new index. 
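# Stdlib-only illustration (table contents invented) of the COALESCE fallback
# used by _redactions_received_ts above: prefer the event's received_ts, then
# its origin_server_ts, then "now", so no redaction is left with a NULL
# received_ts.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE events (event_id TEXT PRIMARY KEY, received_ts INTEGER,
                         origin_server_ts INTEGER);
    CREATE TABLE redactions (event_id TEXT PRIMARY KEY, received_ts INTEGER);
    INSERT INTO events VALUES ('$a', 1000, 900), ('$b', NULL, 800), ('$c', NULL, NULL);
    INSERT INTO redactions VALUES ('$a', NULL), ('$b', NULL), ('$c', NULL);
    """
)
conn.execute(
    """
    UPDATE redactions
    SET received_ts = (
        SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
        WHERE events.event_id = redactions.event_id
    )
    """,
    (12345,),
)
print(conn.execute(
    "SELECT event_id, received_ts FROM redactions ORDER BY event_id"
).fetchall())
# [('$a', 1000), ('$b', 800), ('$c', 12345)]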
+ txn.execute( + """ + UPDATE event_json + SET + json = convert_from(json::bytea, 'utf8') + FROM redactions + WHERE + redactions.have_censored + AND event_json.event_id = redactions.redacts + AND json NOT LIKE '{%'; + """ + ) + + txn.execute("DROP INDEX redactions_censored_redacts") + + yield self.db_pool.runInteraction( + "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn + ) + + yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes") + + return 1 + + @defer.inlineCallbacks + def _event_store_labels(self, progress, batch_size): + """Background update handler which will store labels for existing events.""" + last_event_id = progress.get("last_event_id", "") + + def _event_store_labels_txn(txn): + txn.execute( + """ + SELECT event_id, json FROM event_json + LEFT JOIN event_labels USING (event_id) + WHERE event_id > ? AND label IS NULL + ORDER BY event_id LIMIT ? + """, + (last_event_id, batch_size), + ) + + results = list(txn) + + nbrows = 0 + last_row_event_id = "" + for (event_id, event_json_raw) in results: + try: + event_json = db_to_json(event_json_raw) + + self.db_pool.simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": event_json["room_id"], + "topological_ordering": event_json["depth"], + } + for label in event_json["content"].get( + EventContentFields.LABELS, [] + ) + if isinstance(label, str) + ], + ) + except Exception as e: + logger.warning( + "Unable to load event %s (no labels will be imported): %s", + event_id, + e, + ) + + nbrows += 1 + last_row_event_id = event_id + + self.db_pool.updates._background_update_progress_txn( + txn, "event_store_labels", {"last_event_id": last_row_event_id} + ) + + return nbrows + + num_rows = yield self.db_pool.runInteraction( + desc="event_store_labels", func=_event_store_labels_txn + ) + + if not num_rows: + yield self.db_pool.updates._end_background_update("event_store_labels") + + return num_rows diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py new file mode 100644 index 0000000000..a7b7393f6e --- /dev/null +++ b/synapse/storage/databases/main/events_worker.py @@ -0,0 +1,1454 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division + +import itertools +import logging +import threading +from collections import namedtuple +from typing import List, Optional, Tuple + +from constantly import NamedConstant, Names + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError, SynapseError +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict +from synapse.events.utils import prune_event +from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import BackfillStream +from synapse.replication.tcp.streams.events import EventsStream +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import get_domain_from_id +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) +from synapse.util.iterutils import batch_iter +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + +class EventsWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) + + if hs.config.worker.writers.events == hs.get_instance_name(): + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. 
+ self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + + self._get_event_cache = Cache( + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, + ) + + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(-token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self.db_pool.simple_select_one_onecol( + table="events", + keyvalues={"event_id": event_id}, + retcol="received_ts", + desc="get_received_ts", + ) + + def get_received_ts_by_stream_pos(self, stream_ordering): + """Given a stream ordering get an approximate timestamp of when it + happened. + + This is done by simply taking the received ts of the first event that + has a stream ordering greater than or equal to the given stream pos. + If none exists returns the current time, on the assumption that it must + have happened recently. + + Args: + stream_ordering (int) + + Returns: + Deferred[int] + """ + + def _get_approximate_received_ts_txn(txn): + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering >= ? + LIMIT 1 + """ + + txn.execute(sql, (stream_ordering,)) + row = txn.fetchone() + if row and row[0]: + ts = row[0] + else: + ts = self.clock.time_msec() + + return ts + + return self.db_pool.runInteraction( + "get_approximate_received_ts", _get_approximate_received_ts_txn + ) + + @defer.inlineCallbacks + def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, + ): + """Get an event from the database by event_id. + + Args: + event_id: The event_id of the event to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (behave as per allow_none + if the event is redacted) + + get_prev_content: If True and event is a state event, + include the previous states content in the unsigned field. + + allow_rejected: If True, return rejected events. Otherwise, + behave as per allow_none. + + allow_none: If True, return None if no event found, if + False throw a NotFoundError + + check_room_id: if not None, check the room of the found event. + If there is a mismatch, behave as per allow_none. 
+ + Returns: + Deferred[EventBase|None] + """ + if not isinstance(event_id, str): + raise TypeError("Invalid event event_id %r" % (event_id,)) + + events = yield self.get_events_as_list( + [event_id], + redact_behaviour=redact_behaviour, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event = events[0] if events else None + + if event is not None and check_room_id is not None: + if event.room_id != check_room_id: + event = None + + if event is None and not allow_none: + raise NotFoundError("Could not find event %s" % (event_id,)) + + return event + + @defer.inlineCallbacks + def get_events( + self, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + ): + """Get events from the database + + Args: + event_ids: The event_ids of the events to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (omit them from the response) + + get_prev_content: If True and event is a state event, + include the previous states content in the unsigned field. + + allow_rejected: If True, return rejected events. Otherwise, + omits rejeted events from the response. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self.get_events_as_list( + event_ids, + redact_behaviour=redact_behaviour, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return {e.event_id: e for e in events} + + @defer.inlineCallbacks + def get_events_as_list( + self, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + ): + """Get events from the database and return in a list in the same order + as given by `event_ids` arg. + + Unknown events will be omitted from the response. + + Args: + event_ids: The event_ids of the events to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (omit them from the response) + + get_prev_content: If True and event is a state event, + include the previous states content in the unsigned field. + + allow_rejected: If True, return rejected events. Otherwise, + omits rejected events from the response. + + Returns: + Deferred[list[EventBase]]: List of events fetched from the database. The + events are in the same order as `event_ids` arg. + + Note that the returned list may be smaller than the list of event + IDs if not all events could be fetched. + """ + + if not event_ids: + return [] + + # there may be duplicates so we cast the list to a set + event_entry_map = yield self._get_events_from_cache_or_db( + set(event_ids), allow_rejected=allow_rejected + ) + + events = [] + for event_id in event_ids: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if not allow_rejected: + assert not entry.event.rejected_reason, ( + "rejected event returned from _get_events_from_cache_or_db despite " + "allow_rejected=False" + ) + + # We may not have had the original event when we received a redaction, so + # we have to recheck auth now. 
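+            # Concretely, the checks below only honour a redaction if the event it
+            # targets is one we have, is not the room's create event, is in the same
+            # room, and (where a recheck is flagged) was sent from the same domain as
+            # the original event's sender.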
+ + if not allow_rejected and entry.event.type == EventTypes.Redaction: + if entry.event.redacts is None: + # A redacted redaction doesn't have a `redacts` key, in + # which case lets just withhold the event. + # + # Note: Most of the time if the redactions has been + # redacted we still have the un-redacted event in the DB + # and so we'll still see the `redacts` key. However, this + # isn't always true e.g. if we have censored the event. + logger.debug( + "Withholding redaction event %s as we don't have redacts key", + event_id, + ) + continue + + redacted_event_id = entry.event.redacts + event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + original_event_entry = event_map.get(redacted_event_id) + if not original_event_entry: + # we don't have the redacted event (or it was rejected). + # + # We assume that the redaction isn't authorized for now; if the + # redacted event later turns up, the redaction will be re-checked, + # and if it is found valid, the original will get redacted before it + # is served to the client. + logger.debug( + "Withholding redaction event %s since we don't (yet) have the " + "original %s", + event_id, + redacted_event_id, + ) + continue + + original_event = original_event_entry.event + if original_event.type == EventTypes.Create: + # we never serve redactions of Creates to clients. + logger.info( + "Withholding redaction %s of create event %s", + event_id, + redacted_event_id, + ) + continue + + if original_event.room_id != entry.event.room_id: + logger.info( + "Withholding redaction %s of event %s from a different room", + event_id, + redacted_event_id, + ) + continue + + if entry.event.internal_metadata.need_to_check_redaction(): + original_domain = get_domain_from_id(original_event.sender) + redaction_domain = get_domain_from_id(entry.event.sender) + if original_domain != redaction_domain: + # the senders don't match, so this is forbidden + logger.info( + "Withholding redaction %s whose sender domain %s doesn't " + "match that of redacted event %s %s", + event_id, + redaction_domain, + redacted_event_id, + original_domain, + ) + continue + + # Update the cache to save doing the checks again. + entry.event.internal_metadata.recheck_redaction = False + + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + return events + + @defer.inlineCallbacks + def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the cache or the database. + + If events are pulled from the database, they will be cached for future lookups. + + Unknown events are omitted from the response. + + Args: + + event_ids (Iterable[str]): The event_ids of the events to fetch + + allow_rejected (bool): Whether to include rejected events. If False, + rejected events are omitted from the response. 
+ + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result + """ + event_entry_map = self._get_events_from_cache( + event_ids, allow_rejected=allow_rejected + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + log_ctx = current_context() + log_ctx.record_event_fetch(len(missing_events_ids)) + + # Note that _get_events_from_db is also responsible for turning db rows + # into FrozenEvents (via _get_event_from_row), which involves seeing if + # the events have been redacted, and if so pulling the redaction event out + # of the database to check it. + # + missing_events = yield self._get_events_from_db( + missing_events_ids, allow_rejected=allow_rejected + ) + + event_entry_map.update(missing_events) + + return event_entry_map + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (Iterable[str]): list of event_ids to fetch + allow_rejected (bool): Whether to return events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, update_metrics=update_metrics + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + self._fetch_event_list(conn, event_list) + + def _fetch_event_list(self, conn, event_list): + """Handle a load of requests from the _event_fetch_list queue + + Args: + conn (twisted.enterprise.adbapi.Connection): database connection + + event_list (list[Tuple[list[str], Deferred]]): + The fetch requests. Each entry consists of a list of event + ids to be fetched, and a deferred to be completed once the + events have been fetched. + + The deferreds are callbacked with a dictionary mapping from event id + to event row. Note that it may well contain additional events that + were not part of this request. 
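+        This runs on the database threadpool; the deferreds are resolved back on
+        the reactor thread via ``callFromThread``.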
+ """ + with Measure(self._clock, "_fetch_event_list"): + try: + events_to_fetch = { + event_id for events, _ in event_list for event_id in events + } + + row_dict = self.db_pool.new_transaction( + conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch + ) + + # We only want to resolve deferreds from the main thread + def fire(): + for _, d in event_list: + d.callback(row_dict) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire, event_list, e) + + @defer.inlineCallbacks + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. + + Unknown events are omitted from the response. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + + allow_rejected (bool): Whether to include rejected events. If False, + rejected events are omitted from the response. + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. May return extra events which + weren't asked for. + """ + fetched_events = {} + events_to_fetch = event_ids + + while events_to_fetch: + row_map = yield self._enqueue_events(events_to_fetch) + + # we need to recursively fetch any redactions of those events + redaction_ids = set() + for event_id in events_to_fetch: + row = row_map.get(event_id) + fetched_events[event_id] = row + if row: + redaction_ids.update(row["redactions"]) + + events_to_fetch = redaction_ids.difference(fetched_events.keys()) + if events_to_fetch: + logger.debug("Also fetching redaction events %s", events_to_fetch) + + # build a map from event_id to EventBase + event_map = {} + for event_id, row in fetched_events.items(): + if not row: + continue + assert row["event_id"] == event_id + + rejected_reason = row["rejected_reason"] + + if not allow_rejected and rejected_reason: + continue + + d = db_to_json(row["json"]) + internal_metadata = db_to_json(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. 
+ format_version = EventFormatVersions.V1 + + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.warning( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( + event_dict=d, + room_version=room_version, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + event_map[event_id] = original_ev + + # finally, we can decide whether each one needs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) + result_map[event_id] = cache_entry + + return result_map + + @defer.inlineCallbacks + def _enqueue_events(self, events): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. + """ + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append((events, events_d)) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.db_pool.runWithConnection, self._do_fetch + ) + + logger.debug("Loading %d events: %s", len(events), events) + with PreserveLoggingContext(): + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) + + return row_map + + def _fetch_event_rows(self, txn, event_ids): + """Fetch event rows from the database + + Events which are not found are omitted from the result. + + The returned per-event dicts contain the following keys: + + * event_id (str) + + * json (str): json-encoded event structure + + * internal_metadata (str): json-encoded internal metadata dict + + * format_version (int|None): The format of the event. Hopefully one + of EventFormatVersions. 'None' means the event predates + EventFormatVersions (so the event is format V1). + + * room_version_id (str|None): The version of the room which contains the event. 
+ Hopefully one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + + * rejected_reason (str|None): if the event was rejected, the reason + why. + + * redactions (List[str]): a list of event-ids which (claim to) redact + this event. + + Args: + txn (twisted.enterprise.adbapi.Connection): + event_ids (Iterable[str]): event IDs to fetch + + Returns: + Dict[str, Dict]: a map from event id to event info. + """ + event_dict = {} + for evs in batch_iter(event_ids, 200): + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", evs + ) + + txn.execute(sql + clause, args) + + for row in txn: + event_id = row[0] + event_dict[event_id] = { + "event_id": event_id, + "internal_metadata": row[1], + "json": row[2], + "format_version": row[3], + "room_version_id": row[4], + "rejected_reason": row[5], + "redactions": [], + } + + # check for redactions + redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " + + clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) + + txn.execute(redactions_sql + clause, args) + + for (redacter, redacted) in txn: + d = event_dict.get(redacted) + if d: + d["redactions"].append(redacter) + + return event_dict + + def _maybe_redact_event_row(self, original_ev, redactions, event_map): + """Given an event object and a list of possible redacting event ids, + determine whether to honour any of those redactions and if so return a redacted + event. + + Args: + original_ev (EventBase): + redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. + + Returns: + Deferred[EventBase|None]: if the event should be redacted, a pruned + event object. Otherwise, None. + """ + if original_ev.type == "m.room.create": + # we choose to ignore redactions of m.room.create events. + return None + + for redaction_id in redactions: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: + # we don't have the redaction event, or the redaction event was not + # authorized. + logger.debug( + "%s was redacted by %s but redaction not found/authed", + original_ev.event_id, + redaction_id, + ) + continue + + if redaction_event.room_id != original_ev.room_id: + logger.debug( + "%s was redacted by %s but redaction was in a different room!", + original_ev.event_id, + redaction_id, + ) + continue + + # Starting in room version v3, some redactions need to be + # rechecked if we didn't have the redacted event at the + # time, so we recheck on read instead. + if redaction_event.internal_metadata.need_to_check_redaction(): + expected_domain = get_domain_from_id(original_ev.sender) + if get_domain_from_id(redaction_event.sender) == expected_domain: + # This redaction event is allowed. Mark as not needing a recheck. 
+ redaction_event.internal_metadata.recheck_redaction = False + else: + # Senders don't match, so the event isn't actually redacted + logger.debug( + "%s was redacted by %s but the senders don't match", + original_ev.event_id, + redaction_id, + ) + continue + + logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id) + + # we found a good redaction event. Redact! + redacted_event = prune_event(original_ev) + redacted_event.unsigned["redacted_by"] = redaction_id + + # It's fine to add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = redaction_event + + return redacted_event + + # no valid redaction found for this event + return None + + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + return {r["event_id"] for r in rows} + + @defer.inlineCallbacks + def have_seen_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Args: + event_ids (iterable[str]): + + Returns: + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = "SELECT event_id FROM events as e WHERE " + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", chunk + ) + txn.execute(sql + clause, args) + for (event_id,) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + yield self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn, chunk + ) + return results + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_total_state_event_counts. + """ + # We join against the events table as that has an index on room_id + sql = """ + SELECT COUNT(*) FROM state_events + INNER JOIN events USING (room_id, event_id) + WHERE room_id=? + """ + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.db_pool.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, + room_id, + ) + + def _get_current_state_event_counts_txn(self, txn, room_id): + """ + See get_current_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_current_state_event_counts(self, room_id): + """ + Gets the current number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.db_pool.runInteraction( + "get_current_state_event_counts", + self._get_current_state_event_counts_txn, + room_id, + ) + + @defer.inlineCallbacks + def get_room_complexity(self, room_id): + """ + Get a rough approximation of the complexity of the room. This is used by + remote servers to decide whether they wish to join the room or not. + Higher complexity value indicates that being in the room will consume + more resources. 
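+        The only version currently computed ("v1") is the number of current state
+        events in the room divided by 500, rounded to two decimal places.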
+ + Args: + room_id (str) + + Returns: + Deferred[dict[str:int]] of complexity version to complexity. + """ + state_events = yield self.get_current_state_event_counts(room_id) + + # Call this one "v1", so we can introduce new ones as we want to develop + # it. + complexity_v1 = round(state_events / 500, 2) + + return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + """Returns new events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db_pool.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_ex_outlier_stream_rows(self, last_id, current_id): + """Returns de-outliered events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_ex_outlier_stream_rows_txn(txn): + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering ASC" + ) + + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.db_pool.runInteraction( + "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn + ) + + async def get_all_new_backfill_event_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for backfill replication stream, including all new + backfilled events and events that have gone from being outliers to not. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. 
+ + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: + return [], current_id, False + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = [(row[0], row[1:]) for row in txn] + + limited = False + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + limited = True + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend((row[0], row[1:]) for row in txn) + + if len(new_event_updates) >= limit: + upper_bound = new_event_updates[-1][0] + limited = True + + return new_event_updates, upper_bound, limited + + return await self.db_pool.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, target_row_count)) + return txn.fetchall() + + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. 
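+        # If a caller saw only some of the rows sharing a stream id and then
+        # advanced its token past it, it would never see the rest of that update.
+        # So below we either truncate the results to the last complete stream id,
+        # or re-query that single stream id without a row limit.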
+ + rows = await self.db_pool.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db_pool.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token + ) + + return rows, to_token, True + + @cached(num_args=5, max_entries=10) + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + if have_forward_events: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() + else: + new_forward_events = [] + forward_ex_outliers = [] + + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + if have_backfill_events: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" 
+ " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() + else: + new_backfill_events = [] + backward_ex_outliers = [] + + return AllNewEventsResult( + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, + ) + + return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) + + async def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) + return (to_1, so_1) > (to_2, so_2) + + @cachedInlineCallbacks(max_entries=5000) + def get_event_ordering(self, event_id): + res = yield self.db_pool.simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + return (int(res["topological_ordering"]), int(res["stream_ordering"])) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db_pool.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. + + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db_pool.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. 
+ txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + + +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py new file mode 100644 index 0000000000..cae6bda80e --- /dev/null +++ b/synapse/storage/databases/main/filtering.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from canonicaljson import encode_canonical_json + +from synapse.api.errors import Codes, SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util.caches.descriptors import cachedInlineCallbacks + + +class FilteringStore(SQLBaseStore): + @cachedInlineCallbacks(num_args=2) + def get_user_filter(self, user_localpart, filter_id): + # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail + # with a coherent error message rather than 500 M_UNKNOWN. + try: + int(filter_id) + except ValueError: + raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) + + def_json = yield self.db_pool.simple_select_one_onecol( + table="user_filters", + keyvalues={"user_id": user_localpart, "filter_id": filter_id}, + retcol="filter_json", + allow_none=False, + desc="get_user_filter", + ) + + return db_to_json(def_json) + + def add_user_filter(self, user_localpart, user_filter): + def_json = encode_canonical_json(user_filter) + + # Need an atomic transaction to SELECT the maximal ID so far then + # INSERT a new one + def _do_txn(txn): + sql = ( + "SELECT filter_id FROM user_filters " + "WHERE user_id = ? AND filter_json = ?" + ) + txn.execute(sql, (user_localpart, bytearray(def_json))) + filter_id_response = txn.fetchone() + if filter_id_response is not None: + return filter_id_response[0] + + sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" 
+ txn.execute(sql, (user_localpart,)) + max_id = txn.fetchone()[0] + if max_id is None: + filter_id = 0 + else: + filter_id = max_id + 1 + + sql = ( + "INSERT INTO user_filters (user_id, filter_id, filter_json)" + "VALUES(?, ?, ?)" + ) + txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) + + return filter_id + + return self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py new file mode 100644 index 0000000000..a98181f445 --- /dev/null +++ b/synapse/storage/databases/main/group_server.py @@ -0,0 +1,1297 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json + +# The category ID for the "default" category. We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerWorkerStore(SQLBaseStore): + def get_group(self, group_id): + return self.db_pool.simple_select_one( + table="groups", + keyvalues={"group_id": group_id}, + retcols=( + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + "join_policy", + ), + allow_none=True, + desc="get_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + return self.db_pool.simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin"), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self.db_pool.simple_select_onecol( + table="group_invites", + keyvalues={"group_id": group_id}, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ + # TODO: Pagination + + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? 
+ AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + args = [group_id, False] + + if not include_private: + sql += " AND is_public = ?" + args += [True] + + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] + + return self.db_pool.runInteraction( + "get_rooms_in_group", _get_rooms_in_group_txn + ) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): + """Get the rooms and categories that should be included in a summary request + + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category + """ + + def _get_rooms_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, False, True)) + else: + txn.execute(sql, (group_id, False)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" 
+ txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": db_to_json(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + + return self.db_pool.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self.db_pool.simple_select_list( + table="group_room_categories", + keyvalues={"group_id": group_id}, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + return { + row["category_id"]: { + "is_public": row["is_public"], + "profile": db_to_json(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self.db_pool.simple_select_one( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = db_to_json(category["profile"]) + + return category + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self.db_pool.simple_select_list( + table="group_roles", + keyvalues={"group_id": group_id}, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + return { + row["role_id"]: { + "is_public": row["is_public"], + "profile": db_to_json(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self.db_pool.simple_select_one( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = db_to_json(role["profile"]) + + return role + + def get_local_groups_for_room(self, room_id): + """Get all of the local group that contain a given room + Args: + room_id (str): The ID of a room + Returns: + Deferred[list[str]]: A twisted.Deferred containing a list of group ids + containing this room + """ + return self.db_pool.simple_select_onecol( + table="group_rooms", + keyvalues={"room_id": room_id}, + retcol="group_id", + desc="get_local_groups_for_room", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + + def _get_users_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" 
+ txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": db_to_json(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + + return self.db_pool.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self.db_pool.simple_select_one_onecol( + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + + def _get_users_membership_in_group_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + + if row: + return {"membership": "invite"} + + return {} + + return self.db_pool.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self.db_pool.simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until givent time + """ + + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? + """ + txn.execute(sql, (valid_until_ms,)) + return self.db_pool.cursor_to_dict(txn) + + return self.db_pool.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. 
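+
+        Returns the stored attestation (as a dict) if it has not yet passed
+        its valid_until_ms, otherwise None.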
+ """ + row = yield self.db_pool.simple_select_one( + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + return db_to_json(row["attestation_json"]) + + return None + + def get_joined_groups(self, user_id): + return self.db_pool.simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join"}, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": db_to_json(row[3]), + } + for row in txn + ] + + return self.db_pool.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token + ) + if not has_changed: + return defer.succeed([]) + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token)) + return [ + { + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": db_to_json(content_json), + } + for group_id, membership, gtype, content_json in txn + ] + + return self.db_pool.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn + ) + + async def get_all_groups_changes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + last_id = int(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + + if not has_changed: + return [], current_id, False + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? 
+ """ + txn.execute(sql, (last_id, current_id, limit)) + updates = [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ] + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn + ) + + +class GroupServerStore(GroupServerWorkerStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self.db_pool.simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues={"join_policy": join_policy}, + desc="set_group_join_policy", + ) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.db_pool.runInteraction( + "add_room_to_summary", + self._add_room_to_summary_txn, + group_id, + room_id, + category_id, + order, + is_public, + ) + + def _add_room_to_summary_txn( + self, txn, group_id, room_id, category_id, order, is_public + ): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. + """ + room_in_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute( + """ + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, + (group_id, category_id, group_id, category_id), + ) + + existing = self.db_pool.simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public"), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? 
+ """ + txn.execute(sql, (group_id, category_id)) + (order,) = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self.db_pool.simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self.db_pool.simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self.db_pool.simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = {"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self.db_pool.simple_upsert( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self.db_pool.simple_delete( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + desc="remove_group_category", + ) + + def upsert_group_role(self, group_id, role_id, profile, is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self.db_pool.simple_upsert( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self.db_pool.simple_delete( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.db_pool.runInteraction( + "add_user_to_summary", + self._add_user_to_summary_txn, + group_id, + user_id, + role_id, + order, + is_public, + ) + + def _add_user_to_summary_txn( + self, txn, group_id, user_id, role_id, order, is_public + ): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. 
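+            is_public (bool|None): Public flag for the summary entry. If None
+                it is left unchanged on update and defaults to True on insert.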
+ """ + user_in_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute( + """ + INSERT INTO group_summary_roles + (group_id, role_id, role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, + (group_id, role_id, group_id, role_id), + ) + + existing = self.db_pool.simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, + retcols=("user_order", "is_public"), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? + """ + txn.execute(sql, (group_id, role_id)) + (order,) = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self.db_pool.simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self.db_pool.simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self.db_pool.simple_delete( + table="group_summary_users", + keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, + desc="remove_user_from_summary", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self.db_pool.simple_insert( + table="group_invites", + values={"group_id": group_id, "user_id": user_id}, + desc="add_group_invite", + ) + + def add_user_to_group( + self, + group_id, + user_id, + is_admin=False, + is_public=True, + local_attestation=None, + remote_attestation=None, + ): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. 
Optional if the user and group are on the same server + """ + + def _add_user_to_group_txn(txn): + self.db_pool.simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self.db_pool.simple_delete_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + if local_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + return self.db_pool.runInteraction( + "remove_user_from_group", _remove_user_from_group_txn + ) + + def add_room_to_group(self, group_id, room_id, is_public): + return self.db_pool.simple_insert( + table="group_rooms", + values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self.db_pool.simple_update( + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + ) + + self.db_pool.simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + ) + + return self.db_pool.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self.db_pool.simple_update_one( + table="local_group_membership", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"is_publicised": publicise}, + desc="update_group_publicity", + ) + + @defer.inlineCallbacks + def register_user_group_membership( + self, + group_id, + user_id, + membership, + is_admin=False, + content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. 
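+        The change is also written to local_group_updates (with a new stream
+        id) so that it can be sent down the groups replication stream.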
+ + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. + local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self.db_pool.simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self.db_pool.simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps( + {"membership": membership, "content": content} + ), + }, + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. + + if membership == "join": + if local_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + else: + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + return next_id + + with self._group_updates_id_gen.get_next() as next_id: + res = yield self.db_pool.runInteraction( + "register_user_group_membership", + _register_user_group_membership_txn, + next_id, + ) + return res + + @defer.inlineCallbacks + def create_group( + self, group_id, user_id, name, avatar_url, short_description, long_description + ): + yield self.db_pool.simple_insert( + table="groups", + values={ + "group_id": group_id, + "name": name, + "avatar_url": avatar_url, + "short_description": short_description, + "long_description": long_description, + "is_public": True, + }, + desc="create_group", + ) + + @defer.inlineCallbacks + def update_group_profile(self, group_id, profile): + yield self.db_pool.simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues=profile, + desc="update_group_profile", + ) + + def update_attestation_renewal(self, group_id, user_id, attestation): + """Update an attestation that we have renewed + """ + return self.db_pool.simple_update_one( + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, + desc="update_attestation_renewal", + ) + + def update_remote_attestion(self, group_id, user_id, attestation): + """Update an attestation that a remote has renewed + """ + return self.db_pool.simple_update_one( + 
table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + "attestation_json": json.dumps(attestation), + }, + desc="update_remote_attestion", + ) + + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self.db_pool.simple_delete( + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + desc="remove_attestation_renewal", + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() + + def delete_group(self, group_id): + """Deletes a group fully from the database. + + Args: + group_id (str) + + Returns: + Deferred + """ + + def _delete_group_txn(txn): + tables = [ + "groups", + "group_users", + "group_invites", + "group_rooms", + "group_summary_rooms", + "group_summary_room_categories", + "group_room_categories", + "group_summary_users", + "group_summary_roles", + "group_roles", + "group_attestations_renewals", + "group_attestations_remote", + ] + + for table in tables: + self.db_pool.simple_delete_txn( + txn, table=table, keyvalues={"group_id": group_id} + ) + + return self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py new file mode 100644 index 0000000000..384e9c5eb0 --- /dev/null +++ b/synapse/storage/databases/main/keys.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools +import logging + +from signedjson.key import decode_verify_key_bytes + +from synapse.storage._base import SQLBaseStore +from synapse.storage.keys import FetchKeyResult +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +db_binary_type = memoryview + + +class KeyStore(SQLBaseStore): + """Persistence for signature verification keys + """ + + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ + Args: + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is + unknown + """ + keys = {} + + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for row in txn: + server_name, key_id, key_bytes, ts_valid_until_ms = row + + if ts_valid_until_ms is None: + # Old keys may be stored with a ts_valid_until_ms of null, + # in which case we treat this as if it was set to `0`, i.e. + # it won't match key requests that define a minimum + # `ts_valid_until_ms`. + ts_valid_until_ms = 0 + + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, + ) + keys[(server_name, key_id)] = res + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.db_pool.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. + Args: + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). + """ + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). 
+ invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i,)) + return res + + return self.db_pool.runInteraction( + "store_server_verify_keys", + self.db_pool.simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) + + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): + """Stores the JSON bytes for a set of keys from a server + The JSON should be signed by the originating server, the intermediate + server, and by this server. Updates the value for the + (server_name, key_id, from_server) triplet if one already existed. + Args: + server_name (str): The name of the server. + key_id (str): The identifer of the key this JSON is for. + from_server (str): The server this JSON was fetched from. + ts_now_ms (int): The time now in milliseconds. + ts_valid_until_ms (int): The time when this json stops being valid. + key_json (bytes): The encoded JSON. + """ + return self.db_pool.simple_upsert( + table="server_keys_json", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + }, + values={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + "ts_added_ms": ts_now_ms, + "ts_valid_until_ms": ts_expires_ms, + "key_json": db_binary_type(key_json_bytes), + }, + desc="store_server_keys_json", + ) + + def get_server_keys_json(self, server_keys): + """Retrive the key json for a list of server_keys and key ids. + If no keys are found for a given server, key_id and source then + that server, key_id, and source triplet entry will be an empty list. + The JSON is returned as a byte array so that it can be efficiently + used in an HTTP response. + Args: + server_keys (list): List of (server_name, key_id, source) triplets. + Returns: + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts + """ + + def _get_server_keys_json_txn(txn): + results = {} + for server_name, key_id, from_server in server_keys: + keyvalues = {"server_name": server_name} + if key_id is not None: + keyvalues["key_id"] = key_id + if from_server is not None: + keyvalues["from_server"] = from_server + rows = self.db_pool.simple_select_list_txn( + txn, + "server_keys_json", + keyvalues=keyvalues, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + ) + results[(server_name, key_id, from_server)] = rows + return results + + return self.db_pool.runInteraction( + "get_server_keys_json", _get_server_keys_json_txn + ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py new file mode 100644 index 0000000000..80fc1cd009 --- /dev/null +++ b/synapse/storage/databases/main/media_repository.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool + + +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) + + self.db_pool.updates.register_background_index_update( + update_name="local_media_repository_url_idx", + index_name="local_media_repository_url_idx", + table="local_media_repository", + columns=["created_ts"], + where_clause="url_cache IS NOT NULL", + ) + + +class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): + """Persistence for attachments and avatars""" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + + def get_local_media(self, media_id): + """Get the metadata for a local piece of media + Returns: + None if the media_id doesn't exist. + """ + return self.db_pool.simple_select_one( + "local_media_repository", + {"media_id": media_id}, + ( + "media_type", + "media_length", + "upload_name", + "created_ts", + "quarantined_by", + "url_cache", + ), + allow_none=True, + desc="get_local_media", + ) + + def store_local_media( + self, + media_id, + media_type, + time_now_ms, + upload_name, + media_length, + user_id, + url_cache=None, + ): + return self.db_pool.simple_insert( + "local_media_repository", + { + "media_id": media_id, + "media_type": media_type, + "created_ts": time_now_ms, + "upload_name": upload_name, + "media_length": media_length, + "user_id": user_id.to_string(), + "url_cache": url_cache, + }, + desc="store_local_media", + ) + + def mark_local_media_as_safe(self, media_id: str): + """Mark a local media as safe from quarantining.""" + return self.db_pool.simple_update_one( + table="local_media_repository", + keyvalues={"media_id": media_id}, + updatevalues={"safe_from_quarantine": True}, + desc="mark_local_media_as_safe", + ) + + def get_url_cache(self, url, ts): + """Get the media_id and ts for a cached URL as of the given timestamp + Returns: + None if the URL isn't cached. + """ + + def get_url_cache_txn(txn): + # get the most recently cached result (relative to the given ts) + sql = ( + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts <= ?" + " ORDER BY download_ts DESC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + # ...or if we've requested a timestamp older than the oldest + # copy in the cache, return the oldest copy (if any) + sql = ( + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts > ?" 
+ " ORDER BY download_ts ASC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + return None + + return dict( + zip( + ( + "response_code", + "etag", + "expires_ts", + "og", + "media_id", + "download_ts", + ), + row, + ) + ) + + return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) + + def store_url_cache( + self, url, response_code, etag, expires_ts, og, media_id, download_ts + ): + return self.db_pool.simple_insert( + "local_media_repository_url_cache", + { + "url": url, + "response_code": response_code, + "etag": etag, + "expires_ts": expires_ts, + "og": og, + "media_id": media_id, + "download_ts": download_ts, + }, + desc="store_url_cache", + ) + + def get_local_media_thumbnails(self, media_id): + return self.db_pool.simple_select_list( + "local_media_repository_thumbnails", + {"media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_local_media_thumbnails", + ) + + def store_local_thumbnail( + self, + media_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): + return self.db_pool.simple_insert( + "local_media_repository_thumbnails", + { + "media_id": media_id, + "thumbnail_width": thumbnail_width, + "thumbnail_height": thumbnail_height, + "thumbnail_method": thumbnail_method, + "thumbnail_type": thumbnail_type, + "thumbnail_length": thumbnail_length, + }, + desc="store_local_thumbnail", + ) + + def get_cached_remote_media(self, origin, media_id): + return self.db_pool.simple_select_one( + "remote_media_cache", + {"media_origin": origin, "media_id": media_id}, + ( + "media_type", + "media_length", + "upload_name", + "created_ts", + "filesystem_id", + "quarantined_by", + ), + allow_none=True, + desc="get_cached_remote_media", + ) + + def store_cached_remote_media( + self, + origin, + media_id, + media_type, + media_length, + time_now_ms, + upload_name, + filesystem_id, + ): + return self.db_pool.simple_insert( + "remote_media_cache", + { + "media_origin": origin, + "media_id": media_id, + "media_type": media_type, + "media_length": media_length, + "created_ts": time_now_ms, + "upload_name": upload_name, + "filesystem_id": filesystem_id, + "last_access_ts": time_now_ms, + }, + desc="store_cached_remote_media", + ) + + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ + + def update_cache_txn(txn): + sql = ( + "UPDATE remote_media_cache SET last_access_ts = ?" + " WHERE media_origin = ? AND media_id = ?" + ) + + txn.executemany( + sql, + ( + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + ), + ) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" 
+ ) + + txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + + return self.db_pool.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) + + def get_remote_media_thumbnails(self, origin, media_id): + return self.db_pool.simple_select_list( + "remote_media_cache_thumbnails", + {"media_origin": origin, "media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + desc="get_remote_media_thumbnails", + ) + + def store_remote_media_thumbnail( + self, + origin, + media_id, + filesystem_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): + return self.db_pool.simple_insert( + "remote_media_cache_thumbnails", + { + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": thumbnail_width, + "thumbnail_height": thumbnail_height, + "thumbnail_method": thumbnail_method, + "thumbnail_type": thumbnail_type, + "thumbnail_length": thumbnail_length, + "filesystem_id": filesystem_id, + }, + desc="store_remote_media_thumbnail", + ) + + def get_remote_media_before(self, before_ts): + sql = ( + "SELECT media_origin, media_id, filesystem_id" + " FROM remote_media_cache" + " WHERE last_access_ts < ?" + ) + + return self.db_pool.execute( + "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts + ) + + def delete_remote_media(self, media_origin, media_id): + def delete_remote_media_txn(txn): + self.db_pool.simple_delete_txn( + txn, + "remote_media_cache", + keyvalues={"media_origin": media_origin, "media_id": media_id}, + ) + self.db_pool.simple_delete_txn( + txn, + "remote_media_cache_thumbnails", + keyvalues={"media_origin": media_origin, "media_id": media_id}, + ) + + return self.db_pool.runInteraction( + "delete_remote_media", delete_remote_media_txn + ) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.db_pool.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) + + async def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return await self.db_pool.runInteraction( + "delete_url_cache", _delete_url_cache_txn + ) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.db_pool.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn + ) + + async def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = "DELETE FROM local_media_repository WHERE media_id = ?" + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" 
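+            # The second DELETE below also clears out any thumbnails generated
+            # for those media.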
+ + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return await self.db_pool.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn + ) diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py new file mode 100644 index 0000000000..baa7a5092a --- /dev/null +++ b/synapse/storage/databases/main/metrics.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from collections import Counter + +from twisted.internet import defer + +from synapse.metrics import BucketCollector +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_push_actions import ( + EventPushActionsWorkerStore, +) + + +class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): + """Functions to pull various metrics from the DB, for e.g. phone home + stats and prometheus metrics. + """ + + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + # Collect metrics on the number of forward extremities that exist. + # Counter of number of extremities to count + self._current_forward_extremities_amount = ( + Counter() + ) # type: typing.Counter[int] + + BucketCollector( + "synapse_forward_extremities", + lambda: self._current_forward_extremities_amount, + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], + ) + + # Read the extrems every 60 minutes + def read_forward_extremities(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "read_forward_extremities", self._read_forward_extremities + ) + + hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + + async def _read_forward_extremities(self): + def fetch(txn): + txn.execute( + """ + select count(*) c from event_forward_extremities + group by room_id + """ + ) + return txn.fetchall() + + res = await self.db_pool.runInteraction("read_forward_extremities", fetch) + self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + @defer.inlineCallbacks + def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? 
+ """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db_pool.runInteraction("count_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + ret = yield self.db_pool.runInteraction( + "count_daily_sent_messages", _count_messages + ) + return ret + + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count) + return ret diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py new file mode 100644 index 0000000000..02b01d9619 --- /dev/null +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import List + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the monthly_active_user timestamp +# This means it is not necessary to update the table on every request +LAST_SEEN_GRANULARITY = 60 * 60 * 1000 + + +class MonthlyActiveUsersWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) + self._clock = hs.get_clock() + self.hs = hs + + @cached(num_args=0) + def get_monthly_active_count(self): + """Generates current count of monthly active users + + Returns: + Defered[int]: Number of current monthly active users + """ + + def _count_users(txn): + sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + txn.execute(sql) + (count,) = txn.fetchone() + return count + + return self.db_pool.runInteraction("count_users", _count_users) + + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. 
+ Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. + + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db_pool.runInteraction( + "count_users_by_service", _count_users_by_service + ) + + async def get_registered_reserved_users(self) -> List[str]: + """Of the reserved threepids defined in config, retrieve those that are associated + with registered users + + Returns: + User IDs of actual users that are reserved + """ + users = [] + + for tp in self.hs.config.mau_limits_reserved_threepids[ + : self.hs.config.max_mau_value + ]: + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + users.append(user_id) + + return users + + @cached(num_args=1) + def user_last_seen_monthly_active(self, user_id): + """ + Checks if a given user is part of the monthly active user group + Arguments: + user_id (str): user to add/update + Return: + Deferred[int] : timestamp since last seen, None if never seen + + """ + + return self.db_pool.simple_select_one_onecol( + table="monthly_active_users", + keyvalues={"user_id": user_id}, + retcol="timestamp", + allow_none=True, + desc="user_last_seen_monthly_active", + ) + + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_stats_only = hs.config.mau_stats_only + self._max_mau_value = hs.config.max_mau_value + + # Do not add more reserved users than the total allowable number + # cur = LoggingTransaction( + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? 
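+        # What the loop below actually does: upsert a monthly_active_users row
+        # for each reserved threepid that maps to a registered, non-support user.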
+ + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + + async def reap_monthly_active_users(self): + """Cleans out monthly active user table to ensure that no stale + entries exist. + """ + + def _reap_users(txn, reserved_users): + """ + Args: + reserved_users (tuple): reserved users to preserve + """ + + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + + in_clause, in_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", reserved_users + ) + + txn.execute( + "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" + % (in_clause,), + [thirty_days_ago] + in_clause_args, + ) + + if self._limit_usage_by_mau: + # If MAU user count still exceeds the MAU threshold, then delete on + # a least recently active basis. + # Note it is not possible to write this query using OFFSET due to + # incompatibilities in how sqlite and postgres support the feature. + # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present, + # while Postgres does not require 'LIMIT', but also does not support + # negative LIMIT values. So there is no way to write it that both can + # support + + # Limit must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + self._max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitimate mau is deleted. + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + WHERE NOT %s + ORDER BY timestamp DESC + LIMIT ? + ) + AND NOT %s + """ % ( + in_clause, + in_clause, + ) + + query_args = ( + in_clause_args + + [num_of_non_reserved_users_to_remove] + + in_clause_args + ) + txn.execute(sql, query_args) + + # It seems poor to invalidate the whole cache. Postgres supports + # 'Returning' which would allow me to invalidate only the + # specific users, but sqlite has no way to do this and instead + # I would need to SELECT and the DELETE which without locking + # is racy. + # Have resolved to invalidate the whole cache for now and do + # something about it if and when the perf becomes significant + self._invalidate_all_cache_and_stream( + txn, self.user_last_seen_monthly_active + ) + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + + reserved_users = await self.get_registered_reserved_users() + await self.db_pool.runInteraction( + "reap_monthly_active_users", _reap_users, reserved_users + ) + + @defer.inlineCallbacks + def upsert_monthly_active_user(self, user_id): + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: + user_id (str): user to add/update + + Returns: + Deferred + """ + # Support user never to be included in MAU stats. Note I can't easily call this + # from upsert_monthly_active_user_txn because then I need a _txn form of + # is_support_user which is complicated because I want to cache the result. 
+ # Therefore I call it here and ignore the case where + # upsert_monthly_active_user_txn is called directly from + # _initialise_reserved_users reasoning that it would be very strange to + # include a support user in this context. + + is_support = yield self.is_support_user(user_id) + if is_support: + return + + yield self.db_pool.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id + ) + + def upsert_monthly_active_user_txn(self, txn, user_id): + """Updates or inserts monthly active user member + + We consciously do not call is_support_txn from this method because it + is not possible to cache the response. is_support_txn will be false in + almost all cases, so it seems reasonable to call it only for + upsert_monthly_active_user and to call is_support_txn manually + for cases where upsert_monthly_active_user_txn is called directly, + like _initialise_reserved_users + + In short, don't call this method with support users. (Support users + should not appear in the MAU stats). + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: + bool: True if a new entry was created, False if an + existing one was updated. + """ + + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more + is_insert = self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () + ) + self._invalidate_cache_and_stream( + txn, self.user_last_seen_monthly_active, (user_id,) + ) + + return is_insert + + @defer.inlineCallbacks + def populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + if self._limit_usage_by_mau or self._mau_stats_only: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return + is_trial = yield self.is_trial_user(user_id) + if is_trial: + return + + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + now = self.hs.get_clock().time_msec() + + # We want to reduce to the total number of db writes, and are happy + # to trade accuracy of timestamp in order to lighten load. This means + # We always insert new users (where MAU threshold has not been reached), + # but only update if we have not previously seen the user for + # LAST_SEEN_GRANULARITY ms + if last_seen_timestamp is None: + # In the case where mau_stats_only is True and limit_usage_by_mau is + # False, there is no point in checking get_monthly_active_count - it + # adds no value and will break the logic if max_mau_value is exceeded. 
+                if not self._limit_usage_by_mau:
+                    yield self.upsert_monthly_active_user(user_id)
+                else:
+                    count = yield self.get_monthly_active_count()
+                    if count < self._max_mau_value:
+                        yield self.upsert_monthly_active_user(user_id)
+            elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
+                yield self.upsert_monthly_active_user(user_id)
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
new file mode 100644
index 0000000000..dcd1ff911a
--- /dev/null
+++ b/synapse/storage/databases/main/openid.py
@@ -0,0 +1,33 @@
+from synapse.storage._base import SQLBaseStore
+
+
+class OpenIdStore(SQLBaseStore):
+    def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
+        return self.db_pool.simple_insert(
+            table="open_id_tokens",
+            values={
+                "token": token,
+                "ts_valid_until_ms": ts_valid_until_ms,
+                "user_id": user_id,
+            },
+            desc="insert_open_id_token",
+        )
+
+    def get_user_id_for_open_id_token(self, token, ts_now_ms):
+        def get_user_id_for_token_txn(txn):
+            sql = (
+                "SELECT user_id FROM open_id_tokens"
+                " WHERE token = ? AND ? <= ts_valid_until_ms"
+            )
+
+            txn.execute(sql, (token, ts_now_ms))
+
+            rows = txn.fetchall()
+            if not rows:
+                return None
+            else:
+                return rows[0][0]
+
+        return self.db_pool.runInteraction(
+            "get_user_id_for_token", get_user_id_for_token_txn
+        )
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
new file mode 100644
index 0000000000..99e66dc6e9
--- /dev/null
+++ b/synapse/storage/databases/main/presence.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
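The populate_monthly_active_users logic above reduces to a small decision rule. A minimal, standalone sketch of that rule; the constant value and the parameter names are assumptions for illustration rather than the real store API:

    LAST_SEEN_GRANULARITY = 60 * 60 * 1000  # assumed value, in milliseconds

    def should_touch_mau_row(last_seen_ts, now_ms, limit_by_mau, current_count, max_mau):
        """Decide whether to upsert a row into monthly_active_users."""
        if last_seen_ts is None:
            # Unseen this month: insert, unless the cap is enforced and already hit.
            return (not limit_by_mau) or current_count < max_mau
        # Already seen: refresh the timestamp at most once per granularity window.
        return now_ms - last_seen_ts > LAST_SEEN_GRANULARITY

    assert should_touch_mau_row(None, 0, limit_by_mau=True, current_count=3, max_mau=10)
    assert not should_touch_mau_row(0, 1000, limit_by_mau=False, current_count=0, max_mau=0)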
+ +from typing import List, Tuple + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.presence import UserPresenceState +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + + +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) + ) + + with stream_ordering_manager as stream_orderings: + yield self.db_pool.runInteraction( + "update_presence", + self._update_presence_txn, + stream_orderings, + presence_states, + ) + + return stream_orderings[-1], self._presence_id_gen.get_current_token() + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, state.user_id, stream_id + ) + txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) + + # Actually insert new rows + self.db_pool.simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for stream_id, state in zip(stream_orderings, presence_states) + ], + ) + + # Delete old rows to stop database from getting really big + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " + + for states in batch_iter(presence_states, 50): + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) + + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for presence replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_presence_updates_txn(txn): + sql = """ + SELECT stream_id, user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, + status_msg, + currently_active + FROM presence_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? 
+ """ + txn.execute(sql, (last_id, current_id, limit)) + updates = [(row[0], row[1:]) for row in txn] + + upper_bound = current_id + limited = False + if len(updates) >= limit: + upper_bound = updates[-1][0] + limited = True + + return updates, upper_bound, limited + + return await self.db_pool.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + + @cached() + def _get_presence_for_user(self, user_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def get_presence_for_users(self, user_ids): + rows = yield self.db_pool.simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + desc="get_presence_for_users", + ) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return {row["user_id"]: UserPresenceState(**row) for row in rows} + + def get_current_presence_token(self): + return self._presence_id_gen.get_current_token() + + def allow_presence_visible(self, observed_localpart, observer_userid): + return self.db_pool.simple_insert( + table="presence_allow_inbound", + values={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="allow_presence_visible", + or_ignore=True, + ) + + def disallow_presence_visible(self, observed_localpart, observer_userid): + return self.db_pool.simple_delete_one( + table="presence_allow_inbound", + keyvalues={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="disallow_presence_visible", + ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py new file mode 100644 index 0000000000..4a4f2cb385 --- /dev/null +++ b/synapse/storage/databases/main/profile.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
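The replication helpers in this patch (get_all_presence_updates above, and the equivalents for push rules and pushers later on) all share one contract: they return (updates, upper_bound, limited), and the caller keeps polling from upper_bound while limited is True. A minimal sketch of such a caller, assuming a store object exposing the method above; the "master" instance name and the batch size are illustrative:

    async def drain_presence_updates(store, last_id, current_id, batch_size=100):
        """Collect every presence update in the range (last_id, current_id]."""
        rows = []
        while True:
            updates, upper_bound, limited = await store.get_all_presence_updates(
                "master", last_id, current_id, batch_size
            )
            rows.extend(updates)
            if not limited:
                # Everything up to current_id was returned in this batch.
                return rows
            # Rows remain between upper_bound and current_id; poll again from there.
            last_id = upper_bound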
+ +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.storage.databases.main.roommember import ProfileInfo + + +class ProfileWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self.db_pool.simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + return ProfileInfo(None, None) + else: + raise + + return ProfileInfo( + avatar_url=profile["avatar_url"], display_name=profile["displayname"] + ) + + def get_profile_displayname(self, user_localpart): + return self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="displayname", + desc="get_profile_displayname", + ) + + def get_profile_avatar_url(self, user_localpart): + return self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="avatar_url", + desc="get_profile_avatar_url", + ) + + def get_from_remote_profile_cache(self, user_id): + return self.db_pool.simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url"), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + def create_profile(self, user_localpart): + return self.db_pool.simple_insert( + table="profiles", values={"user_id": user_localpart}, desc="create_profile" + ) + + def set_profile_displayname(self, user_localpart, new_displayname): + return self.db_pool.simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + desc="set_profile_displayname", + ) + + def set_profile_avatar_url(self, user_localpart, new_avatar_url): + return self.db_pool.simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"avatar_url": new_avatar_url}, + desc="set_profile_avatar_url", + ) + + +class ProfileStore(ProfileWorkerStore): + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. 
+ """ + return self.db_pool.simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self.db_pool.simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + updatevalues={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self.db_pool.simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.db_pool.cursor_to_dict(txn) + + return self.db_pool.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = yield self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True + + res = yield self.db_pool.simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py new file mode 100644 index 0000000000..3526b6fd66 --- /dev/null +++ b/synapse/storage/databases/main/purge_events.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
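The remote profile cache methods above are designed to be driven by a periodic sweep: fetch the rows whose last_check is older than some cutoff, evict the ones no longer referenced by any local group, and refresh the rest. A rough sketch of such a sweep, assuming a store exposing the methods above; the 24 hour cutoff and the refetch_profile callable are illustrative, not part of this module:

    async def sweep_remote_profile_cache(store, clock, refetch_profile):
        """Refresh or evict remote profile cache entries not checked for 24 hours."""
        cutoff = clock.time_msec() - 24 * 60 * 60 * 1000
        for entry in await store.get_remote_profile_cache_entries_that_expire(cutoff):
            user_id = entry["user_id"]
            if not await store.is_subscribed_remote_profile_for_user(user_id):
                # Nobody on this server is interested in the user any more: evict.
                await store.maybe_delete_remote_profile_cache(user_id)
                continue
            # Still of interest: fetch a fresh copy and update the cached row.
            profile = await refetch_profile(user_id)
            if profile is not None:
                await store.update_remote_profile_cache(
                    user_id, profile.get("displayname"), profile.get("avatar_url")
                )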
+ +import logging +from typing import Any, Tuple + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.types import RoomStreamToken + +logger = logging.getLogger(__name__) + + +class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + + Returns: + Deferred[set[int]]: The set of state groups that are referenced by + deleted events. + """ + + return self.db_pool.runInteraction( + "purge_history", + self._purge_history_txn, + room_id, + token, + delete_local_events, + ) + + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): + token = RoomStreamToken.parse(token_str) + + # Tables that should be pruned: + # event_auth + # event_backward_extremities + # event_edges + # event_forward_extremities + # event_json + # event_push_actions + # event_reference_hashes + # event_relations + # event_search + # event_to_state_groups + # events + # rejections + # room_depth + # state_groups + # state_groups_state + + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # First ensure that we're not about to delete all the forward extremeties + txn.execute( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?", + (room_id,), + ) + rows = txn.fetchall() + max_depth = max(row[1] for row in rows) + + if max_depth < token.topological: + # We need to ensure we don't delete all the events from the database + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremeties) + raise SynapseError( + 400, "topological_ordering is greater than forward extremeties" + ) + + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () # type: Tuple[Any, ...] + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + + # We include the parameter twice since we use the expression twice + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) + + should_delete_params += (room_id, token.topological) + + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" 
+ % (should_delete_expr, should_delete_expr), + should_delete_params, + ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)" + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") + + txn.execute("SELECT event_id, should_delete FROM events_to_purge") + event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), + sum(1 for e in event_rows if e[1]), + ) + + logger.info("[purge] Finding new backward extremities") + + # We calculate the new entries for the backward extremeties by finding + # events to be purged that are pointed to by events we're not going to + # purge. + txn.execute( + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" + " WHERE ep2.event_id IS NULL" + ) + new_backwards_extrems = txn.fetchall() + + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + + txn.execute( + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) + ) + + # Update backward extremeties + txn.executemany( + "INSERT INTO event_backward_extremities (room_id, event_id)" + " VALUES (?, ?)", + [(room_id, event_id) for event_id, in new_backwards_extrems], + ) + + logger.info("[purge] finding state groups referenced by deleted events") + + # Get all state groups that are referenced by events that are to be + # deleted. + txn.execute( + """ + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) + """ + ) + + referenced_state_groups = {sg for sg, in txn} + logger.info( + "[purge] found %i referenced state groups", len(referenced_state_groups) + ) + + logger.info("[purge] removing events from event_to_state_groups") + txn.execute( + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" + ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + + # Delete all remote non-state events + for table in ( + "events", + "event_json", + "event_auth", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_relations", + "event_search", + "rejections", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,) + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ("event_push_actions",): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id,), + ) + + # Mark all state and own events as outliers + logger.info("[purge] marking remaining events as outliers") + txn.execute( + "UPDATE events SET outlier = ?" 
+ " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + # + # We do this by calculating the minimum depth of the backwards + # extremities. However, the events in event_backward_extremities + # are ones we don't have yet so we need to look at the events that + # point to it via event_edges table. + txn.execute( + """ + SELECT COALESCE(MIN(depth), 0) + FROM event_backward_extremities AS eb + INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id + INNER JOIN events AS e ON e.event_id = eg.event_id + WHERE eb.room_id = ? + """, + (room_id,), + ) + (min_depth,) = txn.fetchone() + + logger.info("[purge] updating room_depth to %d", min_depth) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (min_depth, room_id), + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute("DROP TABLE events_to_purge") + + logger.info("[purge] done") + + return referenced_state_groups + + def purge_room(self, room_id): + """Deletes all record of a room + + Args: + room_id (str) + + Returns: + Deferred[List[int]]: The list of state groups to delete. + """ + + return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) + + def _purge_room_txn(self, txn, room_id): + # First we fetch all the state groups that should be deleted, before + # we delete that information. + txn.execute( + """ + SELECT DISTINCT state_group FROM events + INNER JOIN event_to_state_groups USING(event_id) + WHERE events.room_id = ? + """, + (room_id,), + ) + + state_groups = [row[0] for row in txn] + + # Now we delete tables which lack an index on room_id but have one on event_id + for table in ( + "event_auth", + "event_edges", + "event_push_actions_staging", + "event_reference_hashes", + "event_relations", + "event_to_state_groups", + "redactions", + "rejections", + "state_events", + ): + logger.info("[purge] removing %s from %s", room_id, table) + + txn.execute( + """ + DELETE FROM %s WHERE event_id IN ( + SELECT event_id FROM events WHERE room_id=? + ) + """ + % (table,), + (room_id,), + ) + + # and finally, the tables with an index on room_id (or no useful index) + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + # no useful index, but let's clear them anyway + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "room_account_data", + "room_tags", + "local_current_membership", + ): + logger.info("[purge] removing %s from %s", room_id, table) + txn.execute("DELETE FROM %s WHERE room_id=?" 
% (table,), (room_id,)) + + # Other tables we do NOT need to clear out: + # + # - blocked_rooms + # This is important, to make sure that we don't accidentally rejoin a blocked + # room after it was purged + # + # - user_directory + # This has a room_id column, but it is unused + # + + # Other tables that we might want to consider clearing out include: + # + # - event_reports + # Given that these are intended for abuse management my initial + # inclination is to leave them in place. + # + # - current_state_delta_stream + # - ex_outlier_stream + # - room_tags_revisions + # The problem with these is that they are largeish and there is no room_id + # index on them. In any case we should be clearing out 'stream' tables + # periodically anyway (#5888) + + # TODO: we could probably usefully do a bunch of cache invalidation here + + logger.info("[purge] done") + + return state_groups diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py new file mode 100644 index 0000000000..97cc12931d --- /dev/null +++ b/synapse/storage/databases/main/push_rule.py @@ -0,0 +1,759 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +from typing import List, Tuple, Union + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.push.baserules import list_with_base_rules +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = db_to_json(rawrule["conditions"]) + rule["actions"] = db_to_json(rawrule["actions"]) + rule["default"] = False + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule["rule_id"] + if rule_id in enabled_map: + if rule.get("enabled", True) != bool(enabled_map[rule_id]): + # Rules are cached across users. 
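+                # Take a copy before overriding "enabled" so that we don't
+                # mutate the dict shared with the cache.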
+ rule = dict(rule) + rule["enabled"] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + +class PushRulesWorkerStore( + ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + EventsWorkerStore, + SQLBaseStore, +): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, database: DatabasePool, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) + + push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( + db_conn, + "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_for_user(self, user_id): + rows = yield self.db_pool.simple_select_list( + table="push_rules", + keyvalues={"user_name": user_id}, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="get_push_rules_enabled_for_user", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + return rules + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_enabled_for_user(self, user_id): + results = yield self.db_pool.simple_select_list( + table="push_rules_enable", + keyvalues={"user_name": user_id}, + retcols=("user_name", "rule_id", "enabled"), + desc="get_push_rules_enabled_for_user", + ) + return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? 
< stream_id" + ) + txn.execute(sql, (user_id, last_id)) + (count,) = txn.fetchone() + return bool(count) + + return self.db_pool.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + + @cachedList( + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules(self, user_ids): + if not user_ids: + return {} + + results = {user_id: [] for user_id in user_ids} + + rows = yield self.db_pool.simple_select_many_batch( + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=("*",), + desc="bulk_get_push_rules", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + for row in rows: + results.setdefault(row["user_name"], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + + return results + + @defer.inlineCallbacks + def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + """Copy a single push rule from one room to another for a specific user. + + Args: + new_room_id (str): ID of the new room. + user_id (str): ID of user the push rule belongs to. + rule (Dict): A push rule. + """ + # Create new rule id + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + new_rule_id = rule_id_scope + "/" + new_room_id + + # Change room id in each condition + for condition in rule.get("conditions", []): + if condition.get("key") == "room_id": + condition["pattern"] = new_room_id + + # Add the rule for the new room + yield self.add_push_rule( + user_id=user_id, + rule_id=new_rule_id, + priority_class=rule["priority_class"], + conditions=rule["conditions"], + actions=rule["actions"], + ) + + @defer.inlineCallbacks + def copy_push_rules_from_room_to_room_for_user( + self, old_room_id, new_room_id, user_id + ): + """Copy all of the push rules from one room to another for a specific + user. + + Args: + old_room_id (str): ID of the old room. + new_room_id (str): ID of the new room. + user_id (str): ID of user to copy push rules for. + """ + # Retrieve push rules for this user + user_push_rules = yield self.get_push_rules_for_user(user_id) + + # Get rules relating to the old room and copy them to the new room + for rule in user_push_rules: + conditions = rule.get("conditions", []) + if any( + (c.get("key") == "room_id" and c.get("pattern") == old_room_id) + for c in conditions + ): + yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) + + @defer.inlineCallbacks + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event + ) + return result + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _bulk_get_push_rules_for_room( + self, room_id, state_group, current_state_ids, cache_context, event=None + ): + # We don't use `state_group`, its there so that we can cache based + # on it. 
However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + + users_in_room = yield self._get_joined_users_from_context( + room_id, + state_group, + current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, + ) + + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = { + u + for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + } + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, on_invalidate=cache_context.invalidate + ) + user_ids = { + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + } + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, on_invalidate=cache_context.invalidate + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + return rules_by_user + + @cachedList( + cached_method_name="get_push_rules_enabled_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules_enabled(self, user_ids): + if not user_ids: + return {} + + results = {user_id: {} for user_id in user_ids} + + rows = yield self.db_pool.simple_select_many_batch( + table="push_rules_enable", + column="user_name", + iterable=user_ids, + retcols=("user_name", "rule_id", "enabled"), + desc="bulk_get_push_rules_enabled", + ) + for row in rows: + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled + return results + + async def get_all_push_rule_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for push_rules replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_push_rule_updates_txn(txn): + sql = """ + SELECT stream_id, user_id + FROM push_rules_stream + WHERE ? 
< stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] + + limited = False + upper_bound = current_id + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited + + return await self.db_pool.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + + +class PushRuleStore(PushRulesWorkerStore): + @defer.inlineCallbacks + def add_push_rule( + self, + user_id, + rule_id, + priority_class, + conditions, + actions, + before=None, + after=None, + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + if before or after: + yield self.db_pool.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ) + else: + yield self.db_pool.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ) + + def _add_push_rule_relative_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + relative_to_rule = before or after + + res = self.db_pool.simple_select_one_txn( + txn, + table="push_rules", + keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, + retcols=["priority_class", "priority"], + allow_none=True, + ) + + if not res: + raise RuleNotFoundException( + "before/after rule not found: %s" % (relative_to_rule,) + ) + + base_priority_class = res["priority_class"] + base_rule_priority = res["priority"] + + if base_priority_class != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 + else: + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority + + sql = ( + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" + ) + + txn.execute(sql, (user_id, priority_class, new_rule_priority)) + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_rule_priority, + conditions_json, + actions_json, + ) + + def _add_push_rule_highest_priority_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. 
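+        # The SELECT below finds the current highest priority in this class; the
+        # new rule is then inserted at that priority + 1 so that it runs first.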
+ self.database_engine.lock_table(txn, "push_rules") + + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM push_rules" + " WHERE user_name = ? and priority_class = ?" + ) + txn.execute(sql, (user_id, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_prio, + conditions_json, + actions_json, + ) + + def _upsert_push_rule_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + conditions_json, + actions_json, + update_stream=True, + ): + """Specialised version of simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" + ) + + txn.execute( + sql, + (priority_class, priority, conditions_json, actions_json, user_id, rule_id), + ) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self.db_pool.simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + if update_stream: + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ADD", + data={ + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + @defer.inlineCallbacks + def delete_push_rule(self, user_id, rule_id): + """ + Delete a push rule. 
Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_id (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + self.db_pool.simple_delete_one_txn( + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} + ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.db_pool.runInteraction( + "delete_push_rule", + delete_push_rule_txn, + stream_id, + event_stream_ordering, + ) + + @defer.inlineCallbacks + def set_push_rule_enabled(self, user_id, rule_id, enabled): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.db_pool.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, + ) + + def _set_push_rule_enabled_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + ): + new_id = self._push_rules_enable_id_gen.get_next() + self.db_pool.simple_upsert_txn( + txn, + "push_rules_enable", + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ENABLE" if enabled else "DISABLE", + ) + + @defer.inlineCallbacks + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + "[]", + actions_json, + update_stream=False, + ) + else: + self.db_pool.simple_update_one_txn( + txn, + "push_rules", + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ACTIONS", + data={"actions": actions_json}, + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.db_pool.runInteraction( + "set_push_rule_actions", + set_push_rule_actions_txn, + stream_id, + event_stream_ordering, + ) + + def _insert_push_rules_update_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "event_stream_ordering": event_stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) + txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. 
+ Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_current_token() + + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py new file mode 100644 index 0000000000..b5200fbe79 --- /dev/null +++ b/synapse/storage/databases/main/pusher.py @@ -0,0 +1,356 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Iterable, Iterator, List, Tuple + +from canonicaljson import encode_canonical_json + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList + +logger = logging.getLogger(__name__) + + +class PusherWorkerStore(SQLBaseStore): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ + for r in rows: + dataJson = r["data"] + try: + r["data"] = db_to_json(dataJson) + except Exception as e: + logger.warning( + "Invalid JSON in data for pusher %d: %s, %s", + r["id"], + dataJson, + e.args[0], + ) + continue + + yield r + + @defer.inlineCallbacks + def user_has_pusher(self, user_id): + ret = yield self.db_pool.simple_select_one_onecol( + "pushers", {"user_name": user_id}, "id", allow_none=True + ) + return ret is not None + + def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): + return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) + + def get_pushers_by_user_id(self, user_id): + return self.get_pushers_by({"user_name": user_id}) + + @defer.inlineCallbacks + def get_pushers_by(self, keyvalues): + ret = yield self.db_pool.simple_select_list( + "pushers", + keyvalues, + [ + "id", + "user_name", + "access_token", + "profile_tag", + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "ts", + "lang", + "data", + "last_stream_ordering", + "last_success", + "failing_since", + ], + desc="get_pushers_by", + ) + return self._decode_pushers_rows(ret) + + @defer.inlineCallbacks + def get_all_pushers(self): + def get_pushers(txn): + txn.execute("SELECT * FROM pushers") + rows = self.db_pool.cursor_to_dict(txn) + + return self._decode_pushers_rows(rows) + + rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers) + return rows + + async def get_all_updated_pushers_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for pushers replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. 
+ current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_pushers_rows_txn(txn): + sql = """ + SELECT id, user_name, app_id, pushkey + FROM pushers + WHERE ? < id AND id <= ? + ORDER BY id ASC LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + updates = [ + (stream_id, (user_name, app_id, pushkey, False)) + for stream_id, user_name, app_id, pushkey in txn + ] + + sql = """ + SELECT stream_id, user_id, app_id, pushkey + FROM deleted_pushers + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + updates.extend( + (stream_id, (user_name, app_id, pushkey, True)) + for stream_id, user_name, app_id, pushkey in txn + ) + + updates.sort() # Sort so that they're ordered by stream id + + limited = False + upper_bound = current_id + if len(updates) >= limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited + + return await self.db_pool.runInteraction( + "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn + ) + + @cachedInlineCallbacks(num_args=1, max_entries=15000) + def get_if_user_has_pusher(self, user_id): + # This only exists for the cachedList decorator + raise NotImplementedError() + + @cachedList( + cached_method_name="get_if_user_has_pusher", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def get_if_users_have_pushers(self, user_ids): + rows = yield self.db_pool.simple_select_many_batch( + table="pushers", + column="user_name", + iterable=user_ids, + retcols=["user_name"], + desc="get_if_users_have_pushers", + ) + + result = {user_id: False for user_id in user_ids} + result.update({r["user_name"]: True for r in rows}) + + return result + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering( + self, app_id, pushkey, user_id, last_stream_ordering + ): + yield self.db_pool.simple_update_one( + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, + desc="update_pusher_last_stream_ordering", + ) + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering_and_success( + self, app_id, pushkey, user_id, last_stream_ordering, last_success + ): + """Update the last stream ordering position we've processed up to for + the given pusher. + + Args: + app_id (str) + pushkey (str) + last_stream_ordering (int) + last_success (int) + + Returns: + Deferred[bool]: True if the pusher still exists; False if it has been deleted. 
+ """ + updated = yield self.db_pool.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={ + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, + }, + desc="update_pusher_last_stream_ordering_and_success", + ) + + return bool(updated) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): + yield self.db_pool.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={"failing_since": failing_since}, + desc="update_pusher_failing_since", + ) + + @defer.inlineCallbacks + def get_throttle_params_by_room(self, pusher_id): + res = yield self.db_pool.simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ) + + params_by_room = {} + for row in res: + params_by_room[row["room_id"]] = { + "last_sent_ts": row["last_sent_ts"], + "throttle_ms": row["throttle_ms"], + } + + return params_by_room + + @defer.inlineCallbacks + def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so simple_upsert will retry + yield self.db_pool.simple_upsert( + "pusher_throttle", + {"pusher": pusher_id, "room_id": room_id}, + params, + desc="set_throttle_params", + lock=False, + ) + + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_pusher( + self, + user_id, + access_token, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + pushkey_ts, + lang, + data, + last_stream_ordering, + profile_tag="", + ): + with self._pushers_id_gen.get_next() as stream_id: + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.db_pool.simple_upsert( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": bytearray(encode_canonical_json(data)), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) + + user_has_pusher = self.get_if_user_has_pusher.cache.get( + (user_id,), None, update_metrics=False + ) + + if user_has_pusher is not True: + # invalidate, since we the user might not have had a pusher before + yield self.db_pool.runInteraction( + "add_pusher", + self._invalidate_cache_and_stream, + self.get_if_user_has_pusher, + (user_id,), + ) + + @defer.inlineCallbacks + def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): + def delete_pusher_txn(txn, stream_id): + self._invalidate_cache_and_stream( + txn, self.get_if_user_has_pusher, (user_id,) + ) + + self.db_pool.simple_delete_one_txn( + txn, + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + ) + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. 
+ self.db_pool.simple_insert_txn( + txn, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, + ) + + with self._pushers_id_gen.get_next() as stream_id: + yield self.db_pool.runInteraction( + "delete_pusher", delete_pusher_txn, stream_id + ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py new file mode 100644 index 0000000000..6255977c92 --- /dev/null +++ b/synapse/storage/databases/main/receipts.py @@ -0,0 +1,591 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. 
+ __metaclass__ = abc.ABCMeta + + def __init__(self, database: DatabasePool, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) + + self._receipts_stream_cache = StreamChangeCache( + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() + ) + + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks() + def get_users_with_read_receipts_in_room(self, room_id): + receipts = yield self.get_receipts_for_room(room_id, "m.read") + return {r["user_id"] for r in receipts} + + @cached(num_args=2) + def get_receipts_for_room(self, room_id, receipt_type): + return self.db_pool.simple_select_list( + table="receipts_linearized", + keyvalues={"room_id": room_id, "receipt_type": receipt_type}, + retcols=("user_id", "event_id"), + desc="get_receipts_for_room", + ) + + @cached(num_args=3) + def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): + return self.db_pool.simple_select_one_onecol( + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + }, + retcol="event_id", + desc="get_own_receipt_for_user", + allow_none=True, + ) + + @cachedInlineCallbacks(num_args=2) + def get_receipts_for_user(self, user_id, receipt_type): + rows = yield self.db_pool.simple_select_list( + table="receipts_linearized", + keyvalues={"user_id": user_id, "receipt_type": receipt_type}, + retcols=("room_id", "event_id"), + desc="get_receipts_for_user", + ) + + return {row["room_id"]: row["event_id"] for row in rows} + + @defer.inlineCallbacks + def get_receipts_for_user_with_orderings(self, user_id, receipt_type): + def f(txn): + sql = ( + "SELECT rl.room_id, rl.event_id," + " e.topological_ordering, e.stream_ordering" + " FROM receipts_linearized AS rl" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE rl.room_id = e.room_id" + " AND rl.event_id = e.event_id" + " AND user_id = ?" + ) + txn.execute(sql, (user_id,)) + return txn.fetchall() + + rows = yield self.db_pool.runInteraction( + "get_receipts_for_user_with_orderings", f + ) + return { + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], + } + for row in rows + } + + @defer.inlineCallbacks + def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + """Get receipts for multiple rooms for sending to clients. + + Args: + room_ids (list): List of room_ids. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + list: A list of receipts. + """ + room_ids = set(room_ids) + + if from_key is not None: + # Only ask the database about rooms where there have been new + # receipts added since `from_key` + room_ids = yield self._receipts_stream_cache.get_entities_changed( + room_ids, from_key + ) + + results = yield self._get_linearized_receipts_for_rooms( + room_ids, to_key, from_key=from_key + ) + + return [ev for res in results.values() for ev in res] + + def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """Get receipts for a single room for sending to clients. + + Args: + room_ids (str): The room id. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + Deferred[list]: A list of receipts. 
+ """ + if from_key is not None: + # Check the cache first to see if any new receipts have been added + # since`from_key`. If not we can no-op. + if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + defer.succeed([]) + + return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + + @cachedInlineCallbacks(num_args=3, tree=True) + def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """See get_linearized_receipts_for_room + """ + + def f(txn): + if from_key: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id > ? AND stream_id <= ?" + ) + + txn.execute(sql, (room_id, from_key, to_key)) + else: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id <= ?" + ) + + txn.execute(sql, (room_id, to_key)) + + rows = self.db_pool.cursor_to_dict(txn) + + return rows + + rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f) + + if not rows: + return [] + + content = {} + for row in rows: + content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ + row["user_id"] + ] = db_to_json(row["data"]) + + return [{"type": "m.receipt", "room_id": room_id, "content": content}] + + @cachedList( + cached_method_name="_get_linearized_receipts_for_room", + list_name="room_ids", + num_args=3, + inlineCallbacks=True, + ) + def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + if not room_ids: + return {} + + def f(txn): + if from_key: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + txn.execute(sql + clause, [from_key, to_key] + list(args)) + else: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + txn.execute(sql + clause, [to_key] + list(args)) + + return self.db_pool.cursor_to_dict(txn) + + txn_results = yield self.db_pool.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = db_to_json(row["data"]) + + results = { + room_id: [results[room_id]] if room_id in results else [] + for room_id in room_ids + } + return results + + def get_users_sent_receipts_between(self, last_id: int, current_id: int): + """Get all users who sent receipts between `last_id` exclusive and + `current_id` inclusive. + + Returns: + Deferred[List[str]] + """ + + if last_id == current_id: + return defer.succeed([]) + + def _get_users_sent_receipts_between_txn(txn): + sql = """ + SELECT DISTINCT user_id FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? 
+ """ + txn.execute(sql, (last_id, current_id)) + + return [r[0] for r in txn] + + return self.db_pool.runInteraction( + "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn + ) + + async def get_all_updated_receipts( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for receipts replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_receipts_txn(txn): + sql = """ + SELECT stream_id, room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] + + limited = False + upper_bound = current_id + + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited + + return await self.db_pool.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room( + self, room_id, receipt_type, user_id + ): + if receipt_type != "m.read": + return + + # Returns either an ObservableDeferred or the raw result + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False + ) + + # first handle the ObservableDeferred case + if isinstance(res, ObservableDeferred): + if res.has_called(): + res = res.get_result() + else: + res = None + + if res and user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(database, db_conn, hs) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() + + def insert_linearized_receipt_txn( + self, txn, room_id, receipt_type, user_id, event_id, data, stream_id + ): + """Inserts a read-receipt into the database if it's newer than the current RR + + Returns: int|None + None if the RR is older than the current RR + otherwise, the rx timestamp of the event that the RR corresponds to + (or 0 if the event is unknown) + """ + res = self.db_pool.simple_select_one_txn( + txn, + table="events", + retcols=["stream_ordering", "received_ts"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + stream_ordering = int(res["stream_ordering"]) if res else None 
+ rx_ts = res["received_ts"] if res else 0 + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + if stream_ordering is not None: + sql = ( + "SELECT stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + txn.execute(sql, (room_id, receipt_type, user_id)) + + for so, eid in txn: + if int(so) >= stream_ordering: + logger.debug( + "Ignoring new receipt for %s in favour of existing " + "one for later event %s", + event_id, + eid, + ) + return None + + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) + txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, + receipt_type, + user_id, + ) + txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + # FIXME: This shouldn't invalidate the whole cache + txn.call_after( + self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + ) + + txn.call_after( + self._receipts_stream_cache.entity_has_changed, room_id, stream_id + ) + + txn.call_after( + self.get_last_receipt_event_id_for_user.invalidate, + (user_id, room_id, receipt_type), + ) + + self.db_pool.simple_upsert_txn( + txn, + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + "event_id": event_id, + "data": json.dumps(data), + }, + # receipts_linearized has a unique constraint on + # (user_id, room_id, receipt_type), so no need to lock + lock=False, + ) + + if receipt_type == "m.read" and stream_ordering is not None: + self._remove_old_push_actions_before_txn( + txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering + ) + + return rx_ts + + @defer.inlineCallbacks + def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + """Insert a receipt, either from local client or remote server. + + Automatically does conversion between linearized and graph + representations. + """ + if not event_ids: + return + + if len(event_ids) == 1: + linearized_event_id = event_ids[0] + else: + # we need to points in graph -> linearized form. + # TODO: Make this better. + def graph_to_linear(txn): + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? 
AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + + linearized_event_id = yield self.db_pool.runInteraction( + "insert_receipt_conv", graph_to_linear + ) + + stream_id_manager = self._receipts_id_gen.get_next() + with stream_id_manager as stream_id: + event_ts = yield self.db_pool.runInteraction( + "insert_linearized_receipt", + self.insert_linearized_receipt_txn, + room_id, + receipt_type, + user_id, + linearized_event_id, + data, + stream_id=stream_id, + ) + + if event_ts is None: + return None + + now = self._clock.time_msec() + logger.debug( + "RR for event %s in %s (%i ms old)", + linearized_event_id, + room_id, + now - event_ts, + ) + + yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) + + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): + return self.db_pool.runInteraction( + "insert_graph_receipt", + self.insert_graph_receipt_txn, + room_id, + receipt_type, + user_id, + event_ids, + data, + ) + + def insert_graph_receipt_txn( + self, txn, room_id, receipt_type, user_id, event_ids, data + ): + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) + txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, + receipt_type, + user_id, + ) + txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + # FIXME: This shouldn't invalidate the whole cache + txn.call_after( + self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + ) + + self.db_pool.simple_delete_txn( + txn, + table="receipts_graph", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + }, + ) + self.db_pool.simple_insert_txn( + txn, + table="receipts_graph", + values={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": json.dumps(event_ids), + "data": json.dumps(data), + }, + ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py new file mode 100644 index 0000000000..f618629e09 --- /dev/null +++ b/synapse/storage/databases/main/registration.py @@ -0,0 +1,1588 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import re +from typing import Optional + +from twisted.internet import defer +from twisted.internet.defer import Deferred + +from synapse.api.constants import UserTypes +from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator +from synapse.types import UserID +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 + +logger = logging.getLogger(__name__) + + +class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + self.clock = hs.get_clock() + + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + + @cached() + def get_user_by_id(self, user_id): + return self.db_pool.simple_select_one( + table="users", + keyvalues={"name": user_id}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "consent_version", + "consent_server_notice_sent", + "appservice_id", + "creation_ts", + "user_type", + "deactivated", + ], + allow_none=True, + desc="get_user_by_id", + ) + + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + return False + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + return is_trial + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. + """ + return self.db_pool.runInteraction( + "get_user_by_access_token", self._query_for_auth, token + ) + + @cachedInlineCallbacks() + def get_expiration_ts_for_user(self, user_id): + """Get the expiration timestamp for the account bearing a given user ID. + + Args: + user_id (str): The ID of the user. + Returns: + defer.Deferred: None, if the account has no expiration timestamp, + otherwise int representation of the timestamp (as a number of + milliseconds since epoch). + """ + res = yield self.db_pool.simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="expiration_ts_ms", + allow_none=True, + desc="get_expiration_ts_for_user", + ) + return res + + @defer.inlineCallbacks + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): + """Updates the account validity properties of the given account, with the + given values. + + Args: + user_id (str): ID of the account to update properties for. + expiration_ts (int): New expiration date, as a timestamp in milliseconds + since epoch. 
+ email_sent (bool): True means a renewal email has been sent for this + account and there's no need to send another one for the current validity + period. + renewal_token (str): Renewal token the user can use to extend the validity + of their account. Defaults to no token. + """ + + def set_account_validity_for_user_txn(txn): + self.db_pool.simple_update_txn( + txn=txn, + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={ + "expiration_ts_ms": expiration_ts, + "email_sent": email_sent, + "renewal_token": renewal_token, + }, + ) + self._invalidate_cache_and_stream( + txn, self.get_expiration_ts_for_user, (user_id,) + ) + + yield self.db_pool.runInteraction( + "set_account_validity_for_user", set_account_validity_for_user_txn + ) + + @defer.inlineCallbacks + def set_renewal_token_for_user(self, user_id, renewal_token): + """Defines a renewal token for a given user. + + Args: + user_id (str): ID of the user to set the renewal token for. + renewal_token (str): Random unique string that will be used to renew the + user's account. + + Raises: + StoreError: The provided token is already set for another user. + """ + yield self.db_pool.simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"renewal_token": renewal_token}, + desc="set_renewal_token_for_user", + ) + + @defer.inlineCallbacks + def get_user_from_renewal_token(self, renewal_token): + """Get a user ID from a renewal token. + + Args: + renewal_token (str): The renewal token to perform the lookup with. + + Returns: + defer.Deferred[str]: The ID of the user to which the token belongs. + """ + res = yield self.db_pool.simple_select_one_onecol( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcol="user_id", + desc="get_user_from_renewal_token", + ) + + return res + + @defer.inlineCallbacks + def get_renewal_token_for_user(self, user_id): + """Get the renewal token associated with a given user ID. + + Args: + user_id (str): The user ID to lookup a token for. + + Returns: + defer.Deferred[str]: The renewal token associated with this user ID. + """ + res = yield self.db_pool.simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="renewal_token", + desc="get_renewal_token_for_user", + ) + + return res + + @defer.inlineCallbacks + def get_users_expiring_soon(self): + """Selects users whose account will expire in the [now, now + renew_at] time + window (see configuration for account_validity for information on what renew_at + refers to). + + Returns: + Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + """ + + def select_users_txn(txn, now_ms, renew_at): + sql = ( + "SELECT user_id, expiration_ts_ms FROM account_validity" + " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + ) + values = [False, now_ms, renew_at] + txn.execute(sql, values) + return self.db_pool.cursor_to_dict(txn) + + res = yield self.db_pool.runInteraction( + "get_users_expiring_soon", + select_users_txn, + self.clock.time_msec(), + self.config.account_validity.renew_at, + ) + + return res + + @defer.inlineCallbacks + def set_renewal_mail_status(self, user_id, email_sent): + """Sets or unsets the flag that indicates whether a renewal email has been sent + to the user (and the user hasn't renewed their account yet). + + Args: + user_id (str): ID of the user to set/unset the flag for. + email_sent (bool): Flag which indicates whether a renewal email has been sent + to this user. 
+ """ + yield self.db_pool.simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"email_sent": email_sent}, + desc="set_renewal_mail_status", + ) + + @defer.inlineCallbacks + def delete_account_validity_for_user(self, user_id): + """Deletes the entry for the given user in the account validity table, removing + their expiration date and renewal token. + + Args: + user_id (str): ID of the user to remove from the account validity table. + """ + yield self.db_pool.simple_delete_one( + table="account_validity", + keyvalues={"user_id": user_id}, + desc="delete_account_validity_for_user", + ) + + async def is_server_admin(self, user): + """Determines if a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + + Returns (bool): + true iff the user is a server admin, false otherwise. + """ + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + return bool(res) if res else False + + def set_server_admin(self, user, admin): + """Sets whether a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + admin (bool): true iff the user is to be a server admin, + false otherwise. + """ + + def set_server_admin_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id, access_tokens.valid_until_ms" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.db_pool.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.db_pool.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) + return res + + @cached() + def is_support_user(self, user_id): + """Determines if the user is of type UserTypes.SUPPORT + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user is of type UserTypes.SUPPORT + """ + return self.db_pool.runInteraction( + "is_support_user", self.is_support_user_txn, user_id + ) + + def is_real_user_txn(self, txn, user_id): + res = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return res is None + + def is_support_user_txn(self, txn, user_id): + res = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return True if res == UserTypes.SUPPORT else False + + def get_users_by_id_case_insensitive(self, user_id): + """Gets users that match user_id case insensitively. + Returns a mapping of user_id -> password_hash. 
+ """ + + def f(txn): + sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" + txn.execute(sql, (user_id,)) + return dict(txn) + + return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + + async def get_user_by_external_id( + self, auth_provider: str, external_id: str + ) -> str: + """Look up a user by their external auth id + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + + Returns: + str|None: the mxid of the user, or None if they are not known + """ + return await self.db_pool.simple_select_one_onecol( + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + retcol="user_id", + allow_none=True, + desc="get_user_by_external_id", + ) + + @defer.inlineCallbacks + def count_all_users(self): + """Counts all users registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.db_pool.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.db_pool.runInteraction("count_users", _count_users) + return ret + + def count_daily_user_type(self): + """ + Counts 1) native non guest users + 2) native guests users + 3) bridged users + who registered on the homeserver in the past 24 hours + """ + + def _count_daily_user_type(txn): + yesterday = int(self._clock.time()) - (60 * 60 * 24) + + sql = """ + SELECT user_type, COALESCE(count(*), 0) AS count FROM ( + SELECT + CASE + WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' + WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' + WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' + END AS user_type + FROM users + WHERE creation_ts > ? + ) AS t GROUP BY user_type + """ + results = {"native": 0, "guest": 0, "bridged": 0} + txn.execute(sql, (yesterday,)) + for row in txn: + results[row[0]] = row[1] + return results + + return self.db_pool.runInteraction( + "count_daily_user_type", _count_daily_user_type + ) + + @defer.inlineCallbacks + def count_nonbridged_users(self): + def _count_users(txn): + txn.execute( + """ + SELECT COALESCE(COUNT(*), 0) FROM users + WHERE appservice_id IS NULL + """ + ) + (count,) = txn.fetchone() + return count + + ret = yield self.db_pool.runInteraction("count_users", _count_users) + return ret + + @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") + rows = self.db_pool.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.db_pool.runInteraction("count_real_users", _count_users) + return ret + + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user + + Returns: a (hopefully) free localpart + """ + next_id = await self.db_pool.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn + ) + + return str(next_id) + + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: + """Returns user id from threepid + + Args: + medium: threepid medium e.g. email + address: threepid address e.g. 
me@example.com + + Returns: + The user ID or None if no user id/threepid mapping exists + """ + user_id = await self.db_pool.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address + ) + return user_id + + def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ + ret = self.db_pool.simple_select_one_txn( + txn, + "user_threepids", + {"medium": medium, "address": address}, + ["user_id"], + True, + ) + if ret: + return ret["user_id"] + return None + + @defer.inlineCallbacks + def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + yield self.db_pool.simple_upsert( + "user_threepids", + {"medium": medium, "address": address}, + {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, + ) + + @defer.inlineCallbacks + def user_get_threepids(self, user_id): + ret = yield self.db_pool.simple_select_list( + "user_threepids", + {"user_id": user_id}, + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", + ) + return ret + + def user_delete_threepid(self, user_id, medium, address): + return self.db_pool.simple_delete( + "user_threepids", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + desc="user_delete_threepid", + ) + + def user_delete_threepids(self, user_id: str): + """Delete all threepid this user has bound + + Args: + user_id: The user id to delete all threepids of + + """ + return self.db_pool.simple_delete( + "user_threepids", + keyvalues={"user_id": user_id}, + desc="user_delete_threepids", + ) + + def add_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied a bind request to the given identity server on + behalf of the given user. We need to remember this in case the user + asks us to unbind the threepid. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + # We need to use an upsert, in case they user had already bound the + # threepid + return self.db_pool.simple_upsert( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + values={}, + insertion_values={}, + desc="add_user_bound_threepid", + ) + + def user_get_bound_threepids(self, user_id): + """Get the threepids that a user has bound to an identity server through the homeserver + The homeserver remembers where binds to an identity server occurred. Using this + method can retrieve those threepids. + + Args: + user_id (str): The ID of the user to retrieve threepids for + + Returns: + Deferred[list[dict]]: List of dictionaries containing the following: + medium (str): The medium of the threepid (e.g "email") + address (str): The address of the threepid (e.g "bob@example.com") + """ + return self.db_pool.simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ) + + def remove_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied an unbind request to the given identity server on + behalf of the given user, so we remove the mapping of threepid to + identity server. 
+ + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + return self.db_pool.simple_delete( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + desc="remove_user_bound_threepid", + ) + + def get_id_servers_user_bound(self, user_id, medium, address): + """Get the list of identity servers that the server proxied bind + requests to for given user and threepid + + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[list[str]]: Resolves to a list of identity servers + """ + return self.db_pool.simple_select_onecol( + table="user_threepid_id_server", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + retcol="id_server", + desc="get_id_servers_user_bound", + ) + + @cachedInlineCallbacks() + def get_user_deactivated_status(self, user_id): + """Retrieve the value for the `deactivated` property for the provided user. + + Args: + user_id (str): The ID of the user to retrieve the status for. + + Returns: + defer.Deferred(bool): The requested value. + """ + + res = yield self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="deactivated", + desc="get_user_deactivated_status", + ) + + # Convert the integer into a boolean. + return res == 1 + + def get_threepid_validation_session( + self, medium, client_secret, address=None, sid=None, validated=True + ): + """Gets a session_id and last_send_attempt (if available) for a + combination of validation metadata + + Args: + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str): A unique string provided by the client to help identify this + validation attempt + validated (bool|None): Whether sessions should be filtered by + whether they have been validated already or not. None to + perform no filtering + + Returns: + Deferred[dict|None]: A dict containing the following: + * address - address of the 3pid + * medium - medium of the 3pid + * client_secret - a secret provided by the client for this validation session + * session_id - ID of the validation session + * send_attempt - a number serving to dedupe send attempts for this session + * validated_at - timestamp of when this session was validated if so + + Otherwise None if a validation session is not found + """ + if not client_secret: + raise SynapseError( + 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM + ) + + keyvalues = {"client_secret": client_secret} + if medium: + keyvalues["medium"] = medium + if address: + keyvalues["address"] = address + if sid: + keyvalues["session_id"] = sid + + assert address or sid + + def get_threepid_validation_session_txn(txn): + sql = """ + SELECT address, session_id, medium, client_secret, + last_send_attempt, validated_at + FROM threepid_validation_session WHERE %s + """ % ( + " AND ".join("%s = ?" % k for k in keyvalues.keys()), + ) + + if validated is not None: + sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") + + sql += " LIMIT 1" + + txn.execute(sql, list(keyvalues.values())) + rows = self.db_pool.cursor_to_dict(txn) + if not rows: + return None + + return rows[0] + + return self.db_pool.runInteraction( + "get_threepid_validation_session", get_threepid_validation_session_txn + ) + + def delete_threepid_session(self, session_id): + """Removes a threepid validation session from the database. 
This can + be done after validation has been performed and whatever action was + waiting on it has been carried out + + Args: + session_id (str): The ID of the session to delete + """ + + def delete_threepid_session_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + ) + + return self.db_pool.runInteraction( + "delete_threepid_session", delete_threepid_session_txn + ) + + +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.clock = hs.get_clock() + self.config = hs.config + + self.db_pool.updates.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.db_pool.updates.register_background_index_update( + "users_creation_ts", + index_name="users_creation_ts", + table="users", + columns=["creation_ts"], + ) + + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + self.db_pool.updates.register_noop_background_update( + "refresh_tokens_device_index" + ) + + self.db_pool.updates.register_background_update_handler( + "user_threepids_grandfather", self._bg_user_threepids_grandfather + ) + + self.db_pool.updates.register_background_update_handler( + "users_set_deactivated_flag", self._background_update_set_deactivated_flag + ) + + @defer.inlineCallbacks + def _background_update_set_deactivated_flag(self, progress, batch_size): + """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 + for each of them. + """ + + last_user = progress.get("user_id", "") + + def _background_update_set_deactivated_flag_txn(txn): + txn.execute( + """ + SELECT + users.name, + COUNT(access_tokens.token) AS count_tokens, + COUNT(user_threepids.address) AS count_threepids + FROM users + LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) + LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) + WHERE (users.password_hash IS NULL OR users.password_hash = '') + AND (users.appservice_id IS NULL OR users.appservice_id = '') + AND users.is_guest = 0 + AND users.name > ? 
+ GROUP BY users.name + ORDER BY users.name ASC + LIMIT ?; + """, + (last_user, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + if not rows: + return True, 0 + + rows_processed_nb = 0 + + for user in rows: + if not user["count_tokens"] and not user["count_threepids"]: + self.set_user_deactivated_status_txn(txn, user["name"], True) + rows_processed_nb += 1 + + logger.info("Marked %d rows as deactivated", rows_processed_nb) + + self.db_pool.updates._background_update_progress_txn( + txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} + ) + + if batch_size > len(rows): + return True, len(rows) + else: + return False, len(rows) + + end, nb_processed = yield self.db_pool.runInteraction( + "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn + ) + + if end: + yield self.db_pool.updates._end_background_update( + "users_set_deactivated_flag" + ) + + return nb_processed + + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? + FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.db_pool.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn + ) + + yield self.db_pool.updates._end_background_update("user_threepids_grandfather") + + return 1 + + +class RegistrationStore(RegistrationBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) + + self._account_validity = hs.config.account_validity + + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) + + @defer.inlineCallbacks + def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + """Adds an access token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token + valid_until_ms (int|None): when the token is valid until. None for + no expiry. + Raises: + StoreError if there was a problem adding this. 
+ """ + next_id = self._access_tokens_id_gen.get_next() + + yield self.db_pool.simple_insert( + "access_tokens", + { + "id": next_id, + "user_id": user_id, + "token": token, + "device_id": device_id, + "valid_until_ms": valid_until_ms, + }, + desc="add_access_token_to_user", + ) + + def register_user( + self, + user_id, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + ): + """Attempts to register an account. + + Args: + user_id (str): The desired user ID to register. + password_hash (str|None): Optional. The password hash for this user. + was_guest (bool): Optional. Whether this is a guest account being + upgraded to a non-guest account. + make_guest (boolean): True if the the new user should be guest, + false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. + create_profile_with_displayname (unicode): Optionally create a profile for + the user, setting their displayname to the given value + admin (boolean): is an admin user? + user_type (str|None): type of user. One of the values from + api.constants.UserTypes, or None for a normal user. + + Raises: + StoreError if the user_id could not be registered. + + Returns: + Deferred + """ + return self.db_pool.runInteraction( + "register_user", + self._register_user, + user_id, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + ) + + def _register_user( + self, + txn, + user_id, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + ): + user_id_obj = UserID.from_string(user_id) + + now = int(self.clock.time()) + + try: + if was_guest: + # Ensure that the guest user actually exists + # ``allow_none=False`` makes this raise an exception + # if the row isn't in the database. + self.db_pool.simple_select_one_txn( + txn, + "users", + keyvalues={"name": user_id, "is_guest": 1}, + retcols=("name",), + allow_none=False, + ) + + self.db_pool.simple_update_one_txn( + txn, + "users", + keyvalues={"name": user_id, "is_guest": 1}, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + "user_type": user_type, + }, + ) + else: + self.db_pool.simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + "user_type": user_type, + }, + ) + + except self.database_engine.module.IntegrityError: + raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + if self._account_validity.enabled: + self.set_expiration_date_for_user_txn(txn, user_id) + + if create_profile_with_displayname: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames + # + # *obviously* the 'profiles' table uses localpart for user_id + # while everything else uses the full mxid. 
+ txn.execute( + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (user_id_obj.localpart, create_profile_with_displayname), + ) + + if self.hs.config.stats_enabled: + # we create a new completed user statistics row + + # we don't strictly need current_token since this user really can't + # have any state deltas before now (as it is a new user), but still, + # we include it for completeness. + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + self._update_stats_delta_txn( + txn, now, "user", user_id, {}, complete_with_stream_id=current_token + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self.db_pool.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + + def user_set_password_hash(self, user_id, password_hash): + """ + NB. This does *not* evict any cache because the one use for this + removes most of the entries subsequently anyway so it would be + pointless. Use flush_user separately. + """ + + def user_set_password_hash_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user_id}, {"password_hash": password_hash} + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.db_pool.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) + + def user_set_consent_version(self, user_id, consent_version): + """Updates the user table to record privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy the user has consented + to + + Raises: + StoreError(404) if user not found + """ + + def f(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.db_pool.runInteraction("user_set_consent_version", f) + + def user_set_consent_server_notice_sent(self, user_id, consent_version): + """Updates the user table to record that we have sent the user a server + notice about privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy we have notified the + user about + + Raises: + StoreError(404) if user not found + """ + + def f(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) + + def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + """ + Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str): list of access_tokens IDs which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. 
+ If None, tokens associated with any device (or no device) will + be deleted + Returns: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens + """ + + def f(txn): + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + items = keyvalues.items() + where_clause = " AND ".join(k + " = ?" for k, _ in items) + values = [v for _, v in items] + if except_token_id: + where_clause += " AND id != ?" + values.append(except_token_id) + + txn.execute( + "SELECT token, id, device_id FROM access_tokens WHERE %s" + % where_clause, + values, + ) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] + + for token, _, _ in tokens_and_devices: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + + return tokens_and_devices + + return self.db_pool.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self.db_pool.simple_delete_one_txn( + txn, table="access_tokens", keyvalues={"token": access_token} + ) + + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (access_token,) + ) + + return self.db_pool.runInteraction("delete_access_token", f) + + @cachedInlineCallbacks() + def is_guest(self, user_id): + res = yield self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + + def add_user_pending_deactivation(self, user_id): + """ + Adds a user to the table of users who need to be parted from all the rooms they're + in + """ + return self.db_pool.simple_insert( + "users_pending_deactivation", + values={"user_id": user_id}, + desc="add_user_pending_deactivation", + ) + + def del_user_pending_deactivation(self, user_id): + """ + Removes the given user to the table of users who need to be parted from all the + rooms they're in, effectively marking that user as fully deactivated. + """ + # XXX: This should be simple_delete_one but we failed to put a unique index on + # the table, so somehow duplicate entries have ended up in it. + return self.db_pool.simple_delete( + "users_pending_deactivation", + keyvalues={"user_id": user_id}, + desc="del_user_pending_deactivation", + ) + + def get_user_pending_deactivation(self): + """ + Gets one user from the table of users waiting to be parted from all the rooms + they're in. + """ + return self.db_pool.simple_select_one_onecol( + "users_pending_deactivation", + keyvalues={}, + retcol="user_id", + allow_none=True, + desc="get_users_pending_deactivation", + ) + + def validate_threepid_session(self, session_id, client_secret, token, current_ts): + """Attempt to validate a threepid session using a token + + Args: + session_id (str): The id of a validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + token (str): A validation token + current_ts (int): The current unix time in milliseconds. Used for + checking token expiry status + + Raises: + ThreepidValidationError: if a matching validation token was not found or has + expired + + Returns: + deferred str|None: A str representing a link to redirect the user + to if there is one. 
+ """ + + # Insert everything into a transaction in order to run atomically + def validate_threepid_session_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + retcols=["client_secret", "validated_at"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] + validated_at = row["validated_at"] + + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + + row = self.db_pool.simple_select_one_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id, "token": token}, + retcols=["expires", "next_link"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError( + 400, "Validation token not found or has expired" + ) + expires = row["expires"] + next_link = row["next_link"] + + # If the session is already validated, no need to revalidate + if validated_at: + return next_link + + if expires <= current_ts: + raise ThreepidValidationError( + 400, "This token has expired. Please request a new one" + ) + + # Looks good. Validate the session + self.db_pool.simple_update_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + updatevalues={"validated_at": self.clock.time_msec()}, + ) + + return next_link + + # Return next_link if it exists + return self.db_pool.runInteraction( + "validate_threepid_session_txn", validate_threepid_session_txn + ) + + def upsert_threepid_validation_session( + self, + medium, + address, + client_secret, + send_attempt, + session_id, + validated_at=None, + ): + """Upsert a threepid validation session + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + session_id (str): The id of this validation session + validated_at (int|None): The unix timestamp in milliseconds of + when the session was marked as valid + """ + insertion_values = { + "medium": medium, + "address": address, + "client_secret": client_secret, + } + + if validated_at: + insertion_values["validated_at"] = validated_at + + return self.db_pool.simple_upsert( + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values=insertion_values, + desc="upsert_threepid_validation_session", + ) + + def start_or_continue_validation_session( + self, + medium, + address, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, + ): + """Creates a new threepid validation session if it does not already + exist and associates a new validation token with it + + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + session_id (str): The id of this validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + next_link (str|None): The link to redirect the user to upon + successful validation + token (str): The validation token + token_expires (int): The timestamp for which after the token + will no longer be valid + """ + + def start_or_continue_validation_session_txn(txn): + # Create or update a validation session 
+ self.db_pool.simple_upsert_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values={ + "medium": medium, + "address": address, + "client_secret": client_secret, + }, + ) + + # Create a new validation token with this session ID + self.db_pool.simple_insert_txn( + txn, + table="threepid_validation_token", + values={ + "session_id": session_id, + "token": token, + "next_link": next_link, + "expires": token_expires, + }, + ) + + return self.db_pool.runInteraction( + "start_or_continue_validation_session", + start_or_continue_validation_session_txn, + ) + + def cull_expired_threepid_validation_tokens(self): + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + return txn.execute(sql, (ts,)) + + return self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + @defer.inlineCallbacks + def set_user_deactivated_status(self, user_id, deactivated): + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id (str): The ID of the user to set the status for. + deactivated (bool): The value to set for `deactivated`. + """ + + yield self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + yield self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. 
+ """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py new file mode 100644 index 0000000000..cf9ba51205 --- /dev/null +++ b/synapse/storage/databases/main/rejections.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class RejectionsStore(SQLBaseStore): + def get_rejection_reason(self, event_id): + return self.db_pool.simple_select_one_onecol( + table="rejections", + retcol="reason", + keyvalues={"event_id": event_id}, + allow_none=True, + desc="get_rejection_reason", + ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py new file mode 100644 index 0000000000..b81f1449b7 --- /dev/null +++ b/synapse/storage/databases/main/relations.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +import attr + +from synapse.api.constants import RelationTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.databases.main.stream import generate_pagination_where_clause +from synapse.storage.relations import ( + AggregationPaginationToken, + PaginationChunk, + RelationPaginationToken, +) +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? 
+ """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.db_pool.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.db_pool.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. 
+ + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute(sql, (event_id, RelationTypes.REPLACE)) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.db_pool.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + return edit_event + + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.db_pool.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + + +class RelationsStore(RelationsWorkerStore): + pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py new file mode 100644 index 0000000000..f4008e6221 --- /dev/null +++ b/synapse/storage/databases/main/room.py @@ -0,0 +1,1429 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
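For illustration, the `limit + 1` pagination trick used by `get_relations_for_event` above can be shown in isolation: the query fetches one extra row, and a next-batch token (taken from that extra row) is only handed back when it proves more results exist. This is a simplified sketch with plain tuples standing in for `RelationPaginationToken` and `PaginationChunk`, not the actual store code.

    from typing import List, Optional, Tuple

    Row = Tuple[str, int, int]  # (event_id, topological_ordering, stream_ordering)


    def paginate(rows: List[Row], limit: int):
        """rows is what the LIMIT query (bound to limit + 1) returned, already ordered."""
        events = [{"event_id": event_id} for event_id, _, _ in rows]

        next_batch: Optional[Tuple[int, int]] = None
        if len(rows) > limit:
            # More rows exist: the token comes from the last row fetched,
            # i.e. the first row of the next page.
            _, topo, stream = rows[-1]
            next_batch = (topo, stream)

        # Only the first `limit` events are returned to the caller.
        return events[:limit], next_batch


    chunk, token = paginate([("$a", 3, 30), ("$b", 2, 20), ("$c", 1, 10)], limit=2)
    print(chunk)  # [{'event_id': '$a'}, {'event_id': '$b'}]
    print(token)  # (1, 10)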
+ +import collections +import logging +import re +from abc import abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from canonicaljson import json + +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError +from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.search import SearchStore +from synapse.types import ThirdPartyInstanceID +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + + +OpsLevel = collections.namedtuple( + "OpsLevel", ("ban_level", "kick_level", "redact_level") +) + +RatelimitOverride = collections.namedtuple( + "RatelimitOverride", ("messages_per_second", "burst_count") +) + + +class RoomSortOrder(Enum): + """ + Enum to define the sorting method used when returning rooms with get_rooms_paginate + + NAME = sort rooms alphabetically by name + JOINED_MEMBERS = sort rooms by membership size, highest to lowest + """ + + # ALPHABETICAL and SIZE are deprecated. + # ALPHABETICAL is the same as NAME. + ALPHABETICAL = "alphabetical" + # SIZE is the same as JOINED_MEMBERS. + SIZE = "size" + NAME = "name" + CANONICAL_ALIAS = "canonical_alias" + JOINED_MEMBERS = "joined_members" + JOINED_LOCAL_MEMBERS = "joined_local_members" + VERSION = "version" + CREATOR = "creator" + ENCRYPTION = "encryption" + FEDERATABLE = "federatable" + PUBLIC = "public" + JOIN_RULES = "join_rules" + GUEST_ACCESS = "guest_access" + HISTORY_VISIBILITY = "history_visibility" + STATE_EVENTS = "state_events" + + +class RoomWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RoomWorkerStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + return self.db_pool.simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "is_public", "creator"), + desc="get_room", + allow_none=True, + ) + + def get_room_with_stats(self, room_id: str): + """Retrieve room with statistics. + + Args: + room_id: The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + + def get_room_with_stats_txn(txn, room_id): + sql = """ + SELECT room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, + rooms.creator, state.encryption, state.is_federatable AS federatable, + rooms.is_public AS public, state.join_rules, state.guest_access, + state.history_visibility, curr.current_state_events AS state_events + FROM rooms + LEFT JOIN room_stats_state state USING (room_id) + LEFT JOIN room_stats_current curr USING (room_id) + WHERE room_id = ? 
+ """ + txn.execute(sql, [room_id]) + # Catch error if sql returns empty result to return "None" instead of an error + try: + res = self.db_pool.cursor_to_dict(txn)[0] + except IndexError: + return None + + res["federatable"] = bool(res["federatable"]) + res["public"] = bool(res["public"]) + return res + + return self.db_pool.runInteraction( + "get_room_with_stats", get_room_with_stats_txn, room_id + ) + + def get_public_room_ids(self): + return self.db_pool.simple_select_onecol( + table="rooms", + keyvalues={"is_public": True}, + retcol="room_id", + desc="get_public_room_ids", + ) + + def count_public_rooms(self, network_tuple, ignore_non_federatable): + """Counts the number of public rooms as tracked in the room_stats_current + and room_stats_state table. + + Args: + network_tuple (ThirdPartyInstanceID|None) + ignore_non_federatable (bool): If true filters out non-federatable rooms + """ + + def _count_public_rooms_txn(txn): + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ + + sql = """ + SELECT + COALESCE(COUNT(*), 0) + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + """ % { + "published_sql": published_sql + } + + txn.execute(sql, query_args) + return txn.fetchone()[0] + + return self.db_pool.runInteraction( + "count_public_rooms", _count_public_rooms_txn + ) + + async def get_largest_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + search_filter: Optional[dict], + limit: Optional[int], + bounds: Optional[Tuple[int, str]], + forwards: bool, + ignore_non_federatable: bool = False, + ): + """Gets the largest public rooms (where largest is in terms of joined + members, as tracked in the statistics table). + + Args: + network_tuple + search_filter + limit: Maxmimum number of rows to return, unlimited otherwise. + bounds: An uppoer or lower bound to apply to result set if given, + consists of a joined member count and room_id (these are + excluded from result set). + forwards: true iff going forwards, going backwards otherwise + ignore_non_federatable: If true filters out non-federatable rooms. + + Returns: + Rooms in order: biggest number of joined users first. + We then arbitrarily use the room_id as a tie breaker. + + """ + + where_clauses = [] + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ + + # Work out the bounds if we're given them, these bounds look slightly + # odd, but are designed to help query planner use indices by pulling + # out a common bound. 
+ if bounds: + last_joined_members, last_room_id = bounds + if forwards: + where_clauses.append( + """ + joined_members <= ? AND ( + joined_members < ? OR room_id < ? + ) + """ + ) + else: + where_clauses.append( + """ + joined_members >= ? AND ( + joined_members > ? OR room_id > ? + ) + """ + ) + + query_args += [last_joined_members, last_joined_members, last_room_id] + + if ignore_non_federatable: + where_clauses.append("is_federatable") + + if search_filter and search_filter.get("generic_search_term", None): + search_term = "%" + search_filter["generic_search_term"] + "%" + + where_clauses.append( + """ + ( + LOWER(name) LIKE ? + OR LOWER(topic) LIKE ? + OR LOWER(canonical_alias) LIKE ? + ) + """ + ) + query_args += [ + search_term.lower(), + search_term.lower(), + search_term.lower(), + ] + + where_clause = "" + if where_clauses: + where_clause = " AND " + " AND ".join(where_clauses) + + sql = """ + SELECT + room_id, name, topic, canonical_alias, joined_members, + avatar, history_visibility, joined_members, guest_access + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + %(where_clause)s + ORDER BY joined_members %(dir)s, room_id %(dir)s + """ % { + "published_sql": published_sql, + "where_clause": where_clause, + "dir": "DESC" if forwards else "ASC", + } + + if limit is not None: + query_args.append(limit) + + sql += """ + LIMIT ? + """ + + def _get_largest_public_rooms_txn(txn): + txn.execute(sql, query_args) + + results = self.db_pool.cursor_to_dict(txn) + + if not forwards: + results.reverse() + + return results + + ret_val = await self.db_pool.runInteraction( + "get_largest_public_rooms", _get_largest_public_rooms_txn + ) + return ret_val + + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self.db_pool.simple_select_one_onecol( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + + async def get_rooms_paginate( + self, + start: int, + limit: int, + order_by: RoomSortOrder, + reverse_order: bool, + search_term: Optional[str], + ) -> Tuple[List[Dict[str, Any]], int]: + """Function to retrieve a paginated list of rooms as json. + + Args: + start: offset in the list + limit: maximum amount of rooms to retrieve + order_by: the sort order of the returned list + reverse_order: whether to reverse the room list + search_term: a string to filter room names by + Returns: + A list of room dicts and an integer representing the total number of + rooms that exist given this query + """ + # Filter room names by a string + where_statement = "" + if search_term: + where_statement = "WHERE state.name LIKE ?" + + # Our postgres db driver converts ? -> %s in SQL strings as that's the + # placeholder for postgres. + # HOWEVER, if you put a % into your SQL then everything goes wibbly. 
+ # To get around this, we're going to surround search_term with %'s + # before giving it to the database in python instead + search_term = "%" + search_term + "%" + + # Set ordering + if RoomSortOrder(order_by) == RoomSortOrder.SIZE: + # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS + order_by_column = "curr.joined_members" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL: + # Deprecated in favour of RoomSortOrder.NAME + order_by_column = "state.name" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.NAME: + order_by_column = "state.name" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS: + order_by_column = "state.canonical_alias" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS: + order_by_column = "curr.joined_members" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS: + order_by_column = "curr.local_users_in_room" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.VERSION: + order_by_column = "rooms.room_version" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR: + order_by_column = "rooms.creator" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION: + order_by_column = "state.encryption" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE: + order_by_column = "state.is_federatable" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC: + order_by_column = "rooms.is_public" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES: + order_by_column = "state.join_rules" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS: + order_by_column = "state.guest_access" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY: + order_by_column = "state.history_visibility" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS: + order_by_column = "curr.current_state_events" + order_by_asc = False + else: + raise StoreError( + 500, "Incorrect value for order_by provided: %s" % order_by + ) + + # Whether to return the list in reverse order + if reverse_order: + # Flip the boolean + order_by_asc = not order_by_asc + + # Create one query for getting the limited number of events that the user asked + # for, and another query for getting the total number of events that could be + # returned. Thus allowing us to see if there are more events to paginate through + info_sql = """ + SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room, rooms.room_version, rooms.creator, + state.encryption, state.is_federatable, rooms.is_public, state.join_rules, + state.guest_access, state.history_visibility, curr.current_state_events + FROM room_stats_state state + INNER JOIN room_stats_current curr USING (room_id) + INNER JOIN rooms USING (room_id) + %s + ORDER BY %s %s + LIMIT ? + OFFSET ? 
+ """ % ( + where_statement, + order_by_column, + "ASC" if order_by_asc else "DESC", + ) + + # Use a nested SELECT statement as SQL can't count(*) with an OFFSET + count_sql = """ + SELECT count(*) FROM ( + SELECT room_id FROM room_stats_state state + %s + ) AS get_room_ids + """ % ( + where_statement, + ) + + def _get_rooms_paginate_txn(txn): + # Execute the data query + sql_values = (limit, start) + if search_term: + # Add the search term into the WHERE clause + sql_values = (search_term,) + sql_values + txn.execute(info_sql, sql_values) + + # Refactor room query data into a structured dictionary + rooms = [] + for room in txn: + rooms.append( + { + "room_id": room[0], + "name": room[1], + "canonical_alias": room[2], + "joined_members": room[3], + "joined_local_members": room[4], + "version": room[5], + "creator": room[6], + "encryption": room[7], + "federatable": room[8], + "public": room[9], + "join_rules": room[10], + "guest_access": room[11], + "history_visibility": room[12], + "state_events": room[13], + } + ) + + # Execute the count query + + # Add the search term into the WHERE clause if present + sql_values = (search_term,) if search_term else () + txn.execute(count_sql, sql_values) + + room_count = txn.fetchone() + return rooms, room_count[0] + + return await self.db_pool.runInteraction( + "get_rooms_paginate", _get_rooms_paginate_txn, + ) + + @cached(max_entries=10000) + async def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = await self.db_pool.simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + return RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + ) + else: + return None + + @cached() + async def get_retention_policy_for_room(self, room_id): + """Get the retention policy for a given room. + + If no retention policy has been found for this room, returns a policy defined + by the configured default policy (which has None as both the 'min_lifetime' and + the 'max_lifetime' if no default policy has been defined in the server's + configuration). + + Args: + room_id (str): The ID of the room to get the retention policy of. + + Returns: + dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + """ + + def get_retention_policy_for_room_txn(txn): + txn.execute( + """ + SELECT min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + WHERE room_id = ?; + """, + (room_id,), + ) + + return self.db_pool.cursor_to_dict(txn) + + ret = await self.db_pool.runInteraction( + "get_retention_policy_for_room", get_retention_policy_for_room_txn, + ) + + # If we don't know this room ID, ret will be None, in this case return the default + # policy. + if not ret: + return { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + + row = ret[0] + + # If one of the room's policy's attributes isn't defined, use the matching + # attribute from the default policy. 
+ # The default values will be None if no default policy has been defined, or if one + # of the attributes is missing from the default policy. + if row["min_lifetime"] is None: + row["min_lifetime"] = self.config.retention_default_min_lifetime + + if row["max_lifetime"] is None: + row["max_lifetime"] = self.config.retention_default_max_lifetime + + return row + + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + + return self.db_pool.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) + + def quarantine_media_ids_in_room(self, room_id, quarantined_by): + """For a room loops through all events with media and quarantines + the associated media + """ + + logger.info("Quarantining media in room: %s", room_id) + + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + return self._quarantine_media_txn( + txn, local_mxcs, remote_mxcs, quarantined_by + ) + + return self.db_pool.runInteraction( + "quarantine_media_in_room", _quarantine_media_in_room_txn + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + %(where_clause)s + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) + + local_media_mxcs = [] + remote_media_mxcs = [] + + while True: + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = db_to_json(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + if next_token is None: + # We've gone through the whole room, so we're finished. 
+ break + + txn.execute( + sql % {"where_clause": "AND stream_ordering < ?"}, + (room_id, next_token, True, False, 100), + ) + + return local_media_mxcs, remote_media_mxcs + + def quarantine_media_by_id( + self, server_name: str, media_id: str, quarantined_by: str, + ): + """quarantines a single local or remote media id + + Args: + server_name: The name of the server that holds this media + media_id: The ID of the media to be quarantined + quarantined_by: The user ID that initiated the quarantine request + """ + logger.info("Quarantining media: %s/%s", server_name, media_id) + is_local = server_name == self.config.server_name + + def _quarantine_media_by_id_txn(txn): + local_mxcs = [media_id] if is_local else [] + remote_mxcs = [(server_name, media_id)] if not is_local else [] + + return self._quarantine_media_txn( + txn, local_mxcs, remote_mxcs, quarantined_by + ) + + return self.db_pool.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_id_txn + ) + + def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + """quarantines all local media associated with a single user + + Args: + user_id: The ID of the user to quarantine media of + quarantined_by: The ID of the user who made the quarantine request + """ + + def _quarantine_media_by_user_txn(txn): + local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) + return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) + + return self.db_pool.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_user_txn + ) + + def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + """Retrieves local media IDs by a given user + + Args: + txn (cursor) + user_id: The ID of the user to retrieve media IDs of + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + # Local media + sql = """ + SELECT media_id + FROM local_media_repository + WHERE user_id = ? + """ + if filter_quarantined: + sql += "AND quarantined_by IS NULL" + txn.execute(sql, (user_id,)) + + local_media_ids = [row[0] for row in txn] + + # TODO: Figure out all remote media a user has referenced in a message + + return local_media_ids + + def _quarantine_media_txn( + self, + txn, + local_mxcs: List[str], + remote_mxcs: List[Tuple[str, str]], + quarantined_by: str, + ) -> int: + """Quarantine local and remote media items + + Args: + txn (cursor) + local_mxcs: A list of local mxc URLs + remote_mxcs: A list of (remote server, media id) tuples representing + remote mxc URLs + quarantined_by: The ID of the user who initiated the quarantine request + Returns: + The total number of media items quarantined + """ + # Update all the tables to set the quarantined_by flag + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? AND safe_from_quarantine = ? + """, + ((quarantined_by, media_id, False) for media_id in local_mxcs), + ) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? 
+ """, + ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + ) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + + async def get_all_new_public_rooms( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for public rooms replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: + return [], current_id, False + + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (last_id, current_id, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + + +class RoomBackgroundUpdateStore(SQLBaseStore): + REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" + ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + self.db_pool.updates.register_background_update_handler( + "insert_room_retention", self._background_insert_retention, + ) + + self.db_pool.updates.register_background_update_handler( + self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, + self._remove_tombstoned_rooms_from_directory, + ) + + self.db_pool.updates.register_background_update_handler( + self.ADD_ROOMS_ROOM_VERSION_COLUMN, + self._background_add_rooms_room_version_column, + ) + + async def _background_insert_retention(self, progress, batch_size): + """Retrieves a list of all rooms within a range and inserts an entry for each of + them into the room_retention table. + NULLs the property's columns if missing from the retention event in the room's + state (or NULLs all of them if there's no retention event in the room's state), + so that we fall back to the server's retention policy. + """ + + last_room = progress.get("room_id", "") + + def _background_insert_retention_txn(txn): + txn.execute( + """ + SELECT state.room_id, state.event_id, events.json + FROM current_state_events as state + LEFT JOIN event_json AS events ON (state.event_id = events.event_id) + WHERE state.room_id > ? 
AND state.type = '%s' + ORDER BY state.room_id ASC + LIMIT ?; + """ + % EventTypes.Retention, + (last_room, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + if not rows: + return True + + for row in rows: + if not row["json"]: + retention_policy = {} + else: + ev = db_to_json(row["json"]) + retention_policy = ev["content"] + + self.db_pool.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": row["room_id"], + "event_id": row["event_id"], + "min_lifetime": retention_policy.get("min_lifetime"), + "max_lifetime": retention_policy.get("max_lifetime"), + }, + ) + + logger.info("Inserted %d rows into room_retention", len(rows)) + + self.db_pool.updates._background_update_progress_txn( + txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} + ) + + if batch_size > len(rows): + return True + else: + return False + + end = await self.db_pool.runInteraction( + "insert_room_retention", _background_insert_retention_txn, + ) + + if end: + await self.db_pool.updates._end_background_update("insert_room_retention") + + return batch_size + + async def _background_add_rooms_room_version_column( + self, progress: dict, batch_size: int + ): + """Background update to go and add room version inforamtion to `rooms` + table from `current_state_events` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): + sql = """ + SELECT room_id, json FROM current_state_events + INNER JOIN event_json USING (room_id, event_id) + WHERE room_id > ? AND type = 'm.room.create' AND state_key = '' + ORDER BY room_id + LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + + updates = [] + for room_id, event_json in txn: + event_dict = db_to_json(event_json) + room_version_id = event_dict.get("content", {}).get( + "room_version", RoomVersions.V1.identifier + ) + + creator = event_dict.get("content").get("creator") + + updates.append((room_id, creator, room_version_id)) + + if not updates: + return True + + new_last_room_id = "" + for room_id, creator, room_version_id in updates: + # We upsert here just in case we don't already have a row, + # mainly for paranoia as much badness would happen if we don't + # insert the row and then try and get the room version for the + # room. + self.db_pool.simple_upsert_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version_id}, + insertion_values={"is_public": False, "creator": creator}, + ) + new_last_room_id = room_id + + self.db_pool.updates._background_update_progress_txn( + txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} + ) + + return False + + end = await self.db_pool.runInteraction( + "_background_add_rooms_room_version_column", + _background_add_rooms_room_version_column_txn, + ) + + if end: + await self.db_pool.updates._end_background_update( + self.ADD_ROOMS_ROOM_VERSION_COLUMN + ) + + return batch_size + + async def _remove_tombstoned_rooms_from_directory( + self, progress, batch_size + ) -> int: + """Removes any rooms with tombstone events from the room directory + + Nowadays this is handled by the room upgrade handler, but we may have some + that got left behind + """ + + last_room = progress.get("room_id", "") + + def _get_rooms(txn): + txn.execute( + """ + SELECT room_id + FROM rooms r + INNER JOIN current_state_events cse USING (room_id) + WHERE room_id > ? 
AND r.is_public + AND cse.type = '%s' AND cse.state_key = '' + ORDER BY room_id ASC + LIMIT ?; + """ + % EventTypes.Tombstone, + (last_room, batch_size), + ) + + return [row[0] for row in txn] + + rooms = await self.db_pool.runInteraction( + "get_tombstoned_directory_rooms", _get_rooms + ) + + if not rooms: + await self.db_pool.updates._end_background_update( + self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE + ) + return 0 + + for room_id in rooms: + logger.info("Removing tombstoned room %s from the directory", room_id) + await self.set_room_is_public(room_id, False) + + await self.db_pool.updates._background_update_progress( + self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} + ) + + return len(rooms) + + @abstractmethod + def set_room_is_public(self, room_id, is_public): + # this will need to be implemented if a background update is performed with + # existing (tombstoned, public) rooms in the database. + # + # It's overridden by RoomStore for the synapse master. + raise NotImplementedError() + + +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): + """Ensure that the room is stored in the table + + Called when we join a room over federation, and overwrites any room version + currently in the table. + """ + await self.db_pool.simple_upsert( + desc="upsert_room_on_join", + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version.identifier}, + insertion_values={"is_public": False, "creator": ""}, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. + lock=False, + ) + + async def store_room( + self, + room_id: str, + room_creator_user_id: str, + is_public: bool, + room_version: RoomVersion, + ): + """Stores a room. + + Args: + room_id: The desired room ID, can be None. + room_creator_user_id: The user ID of the room creator. + is_public: True to indicate that this room should appear in + public room lists. + room_version: The version of the room + Raises: + StoreError if the room could not be stored. + """ + try: + + def store_room_txn(txn, next_id): + self.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + }, + ) + if is_public: + self.db_pool.simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + await self.db_pool.runInteraction( + "store_room_txn", store_room_txn, next_id + ) + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + + async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + """ + When we receive an invite over federation, store the version of the room if we + don't already know the room version. + """ + await self.db_pool.simple_upsert( + desc="maybe_store_room_on_invite", + table="rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + "room_version": room_version.identifier, + "is_public": False, + "creator": "", + }, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. 
+ lock=False, + ) + + async def set_room_is_public(self, room_id, is_public): + def set_room_is_public_txn(txn, next_id): + self.db_pool.simple_update_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + ) + + entries = self.db_pool.simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={ + "room_id": room_id, + "appservice_id": None, + "network_id": None, + }, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self.db_pool.simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + "appservice_id": None, + "network_id": None, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + await self.db_pool.runInteraction( + "set_room_is_public", set_room_is_public_txn, next_id + ) + self.hs.get_notifier().on_new_replication_data() + + async def set_room_is_public_appservice( + self, room_id, appservice_id, network_id, is_public + ): + """Edit the appservice/network specific public room list. + + Each appservice can have a number of published room lists associated + with them, keyed off of an appservice defined `network_id`, which + basically represents a single instance of a bridge to a third party + network. + + Args: + room_id (str) + appservice_id (str) + network_id (str) + is_public (bool): Whether to publish or unpublish the room from the + list. + """ + + def set_room_is_public_appservice_txn(txn, next_id): + if is_public: + try: + self.db_pool.simple_insert_txn( + txn, + table="appservice_room_list", + values={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + ) + except self.database_engine.module.IntegrityError: + # We've already inserted, nothing to do. 
+ return + else: + self.db_pool.simple_delete_txn( + txn, + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + ) + + entries = self.db_pool.simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={ + "room_id": room_id, + "appservice_id": appservice_id, + "network_id": network_id, + }, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self.db_pool.simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + "appservice_id": appservice_id, + "network_id": network_id, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + await self.db_pool.runInteraction( + "set_room_is_public_appservice", + set_room_is_public_appservice_txn, + next_id, + ) + self.hs.get_notifier().on_new_replication_data() + + def get_room_count(self): + """Retrieve a list of all rooms + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return self.db_pool.runInteraction("get_rooms", f) + + def add_event_report( + self, room_id, event_id, user_id, reason, content, received_ts + ): + next_id = self._event_reports_id_gen.get_next() + return self.db_pool.simple_insert( + table="event_reports", + values={ + "id": next_id, + "received_ts": received_ts, + "room_id": room_id, + "event_id": event_id, + "user_id": user_id, + "reason": reason, + "content": json.dumps(content), + }, + desc="add_event_report", + ) + + def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() + + async def block_room(self, room_id: str, user_id: str) -> None: + """Marks the room as blocked. Can be called multiple times. + + Args: + room_id: Room to block + user_id: Who blocked it + """ + await self.db_pool.simple_upsert( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"user_id": user_id}, + desc="block_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. + + Args: + min_ms: Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms: Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null: Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). 
+ """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.db_pool.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.db_pool.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + rooms = await self.db_pool.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + + return rooms diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py new file mode 100644 index 0000000000..7c5be251bd --- /dev/null +++ b/synapse/storage/databases/main/roommember.py @@ -0,0 +1,1139 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Iterable, List, Set + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + db_to_json, + make_in_list_sql_clause, +) +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.engines import Sqlite3Engine +from synapse.storage.roommember import ( + GetRoomsForUserWithStreamOrdering, + MemberSummary, + ProfileInfo, + RoomsForUser, +) +from synapse.types import Collection, get_domain_from_id +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" +_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" + + +class RoomMemberWorkerStore(EventsWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) + + # Is the current_state_events.membership up to date? Or is the + # background update still running? + self._current_state_events_membership_up_to_date = False + + txn = LoggingTransaction( + db_conn.cursor(), + name="_check_safe_current_state_events_membership_updated", + database_engine=self.database_engine, + ) + self._check_safe_current_state_events_membership_updated_txn(txn) + txn.close() + + if self.hs.config.metrics_flags.known_servers: + self._known_servers_count = 1 + self.hs.get_clock().looping_call( + run_as_background_process, + 60 * 1000, + "_count_known_servers", + self._count_known_servers, + ) + self.hs.get_clock().call_later( + 1000, + run_as_background_process, + "_count_known_servers", + self._count_known_servers, + ) + LaterGauge( + "synapse_federation_known_servers", + "", + [], + lambda: self._known_servers_count, + ) + + @defer.inlineCallbacks + def _count_known_servers(self): + """ + Count the servers that this server knows about. + + The statistic is stored on the class for the + `synapse_federation_known_servers` LaterGauge to collect. + """ + + def _transact(txn): + if isinstance(self.database_engine, Sqlite3Engine): + query = """ + SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) + FROM ( + SELECT rm.user_id as user_id, instr(rm.user_id, ':') + AS pos FROM room_memberships as rm + INNER JOIN current_state_events as c ON rm.event_id = c.event_id + WHERE c.type = 'm.room.member' + ) as out + """ + else: + query = """ + SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) + FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join'; + """ + txn.execute(query) + return list(txn)[0][0] + + count = yield self.db_pool.runInteraction("get_known_servers", _transact) + + # We always know about ourselves, even if we have nothing in + # room_memberships (for example, the server is new). 
+ self._known_servers_count = max([count, 1]) + return self._known_servers_count + + def _check_safe_current_state_events_membership_updated_txn(self, txn): + """Checks if it is safe to assume the new current_state_events + membership column is up to date + """ + + pending_update = self.db_pool.simple_select_one_txn( + txn, + table="background_updates", + keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, + retcols=["update_name"], + allow_none=True, + ) + + self._current_state_events_membership_up_to_date = not pending_update + + # If the update is still running, reschedule to run. + if pending_update: + self._clock.call_later( + 15.0, + run_as_background_process, + "_check_safe_current_state_events_membership_updated", + self.db_pool.runInteraction, + "_check_safe_current_state_events_membership_updated", + self._check_safe_current_state_events_membership_updated_txn, + ) + + @cached(max_entries=100000, iterable=True) + def get_users_in_room(self, room_id): + return self.db_pool.runInteraction( + "get_users_in_room", self.get_users_in_room_txn, room_id + ) + + def get_users_in_room_txn(self, txn, room_id): + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT state_key FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + """ + else: + sql = """ + SELECT state_key FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + """ + + txn.execute(sql, (room_id, Membership.JOIN)) + return [r[0] for r in txn] + + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + GROUP BY membership + """ + else: + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? 
+ GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(membership, MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT state_key, membership, event_id + FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + ORDER BY + CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + event_id ASC + LIMIT ? + """ + else: + sql = """ + SELECT c.state_key, m.membership, c.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c USING (room_id, event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + c.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[membership] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((user_id, event_id)) + + return res + + return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) + + def _get_user_counts_in_room_txn(self, txn, room_id): + """ + Get the user count in a room by membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + sql = """ + SELECT m.membership, count(*) FROM room_memberships as m + INNER JOIN current_state_events as c USING(event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + return {row[0]: row[1] for row in txn} + + @cached() + def get_invited_rooms_for_local_user(self, user_id): + """ Get all the rooms the *local* user is invited to + + Args: + user_id (str): The user ID. + Returns: + A deferred list of RoomsForUser. + """ + + return self.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE] + ) + + @defer.inlineCallbacks + def get_invite_for_local_user_in_room(self, user_id, room_id): + """Gets the invite for the given *local* user and room + + Args: + user_id (str) + room_id (str) + + Returns: + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. + """ + invites = yield self.get_invited_rooms_for_local_user(user_id) + for invite in invites: + if invite.room_id == room_id: + return invite + return None + + @defer.inlineCallbacks + def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this *local* user where the membership for this user + matches one in the membership list. + + Filters out forgotten rooms. + + Args: + user_id (str): The user ID. + membership_list (list): A list of synapse.api.constants.Membership + values which the user must be in. 
+ + Returns: + Deferred[list[RoomsForUser]] + """ + if not membership_list: + return defer.succeed(None) + + rooms = yield self.db_pool.runInteraction( + "get_rooms_for_local_user_where_membership_is", + self._get_rooms_for_local_user_where_membership_is_txn, + user_id, + membership_list, + ) + + # Now we filter out forgotten rooms + forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + return [room for room in rooms if room.room_id not in forgotten_rooms] + + def _get_rooms_for_local_user_where_membership_is_txn( + self, txn, user_id, membership_list + ): + # Paranoia check. + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r" + % (user_id,), + ) + + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) + + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM local_current_membership AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + user_id = ? + AND %s + """ % ( + clause, + ) + + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] + + return results + + @cached(max_entries=500000, iterable=True) + def get_rooms_for_user_with_stream_ordering(self, user_id): + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. + + Args: + user_id (str) + + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. + """ + return self.db_pool.runInteraction( + "get_rooms_for_user_with_stream_ordering", + self._get_rooms_for_user_with_stream_ordering_txn, + user_id, + ) + + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + # We use `current_state_events` here and not `local_current_membership` + # as a) this gets called with remote users and b) this only gets called + # for rooms the server is participating in. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND c.membership = ? + """ + else: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND m.membership = ? + """ + + txn.execute(sql, (user_id, Membership.JOIN)) + results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + + return results + + async def get_users_server_still_shares_room_with( + self, user_ids: Collection[str] + ) -> Set[str]: + """Given a list of users return the set that the server still share a + room with. 
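Several queries in this store build their IN (...) filters with make_in_list_sql_clause, which returns a SQL fragment plus the arguments to bind for it. A simplified, engine-agnostic sketch of the idea (not Synapse's implementation, which can also emit an `= ANY(?)` clause on PostgreSQL):

    def make_in_list_clause(column, values):
        # One positional placeholder per value, e.g.
        # ("c.membership IN (?, ?)", ["join", "invite"]).
        values = list(values)
        placeholders = ", ".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), values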
+ """ + + if not user_ids: + return set() + + def _get_users_server_still_shares_room_with_txn(txn): + sql = """ + SELECT state_key FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND %s + GROUP BY state_key + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids + ) + + txn.execute(sql % (clause,), args) + + return {row[0] for row in txn} + + return await self.db_pool.runInteraction( + "get_users_server_still_shares_room_with", + _get_users_server_still_shares_room_with_txn, + ) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate + ) + return frozenset(r.room_id for r in rooms) + + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + room_ids = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + + user_who_share_room = set() + for room_id in room_ids: + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + user_who_share_room.update(user_ids) + + return user_who_share_room + + @defer.inlineCallbacks + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context + ) + return result + + @defer.inlineCallbacks + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_users_from_state"): + return ( + yield self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry + ) + ) + + @cachedInlineCallbacks( + num_args=2, cache_context=True, iterable=True, max_entries=100000 + ) + def _get_joined_users_from_context( + self, + room_id, + state_group, + current_state_ids, + cache_context, + event=None, + context=None, + ): + # We don't use `state_group`, it's there so that we can cache based + # on it. However, it's important that it's never None, since two current_states + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. 
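The callers above substitute a fresh `object()` whenever an event has no state group yet precisely so that this cache never treats two unrelated "unknown" states as the same key. A minimal sketch of why that works, with a plain dict standing in for Synapse's cache machinery:

    _joined_users_cache = {}  # illustrative stand-in, not Synapse's cache class

    def cached_by_state_group(room_id, state_group, compute):
        # Distinct object() sentinels never compare equal and hash by
        # identity, so two "no state group assigned yet" lookups can never
        # share an entry, whereas two None keys would collide.
        key = (room_id, state_group)
        if key not in _joined_users_cache:
            _joined_users_cache[key] = compute()
        return _joined_users_cache[key]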
+ assert state_group is not None + + users_in_room = {} + member_event_ids = [ + e_id + for key, e_id in current_state_ids.items() + if key[0] == EventTypes.Member + ] + + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. + # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in context.delta_ids.items() + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + if etype == EventTypes.Member: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, allow_rejected=False, update_metrics=False + ) + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[ev_entry.event.state_key] = ProfileInfo( + display_name=ev_entry.event.content.get("displayname", None), + avatar_url=ev_entry.event.content.get("avatar_url", None), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + event_to_memberships = yield self._get_joined_profiles_from_event_ids( + missing_member_event_ids + ) + users_in_room.update((row for row in event_to_memberships.values() if row)) + + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room[event.state_key] = ProfileInfo( + display_name=event.content.get("displayname", None), + avatar_url=event.content.get("avatar_url", None), + ) + + return users_in_room + + @cached(max_entries=10000) + def _get_joined_profile_from_event_id(self, event_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", + inlineCallbacks=True, + ) + def _get_joined_profiles_from_event_ids(self, event_ids): + """For given set of member event_ids check if they point to a join + event and if so return the associated user and profile info. + + Args: + event_ids (Iterable[str]): The member event IDs to lookup + + Returns: + Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + to `user_id` and ProfileInfo (or None if not join event). 
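The lookup above follows a two-tier shape: serve whatever the in-memory event cache already holds, then issue one batched query for the rest. A sketch of that shape with illustrative callables (`cache_lookup` and `batch_fetch` are stand-ins, not Synapse APIs):

    def resolve_member_profiles(event_ids, cache_lookup, batch_fetch):
        # cache_lookup returns {event_id: profile} for cache hits only;
        # batch_fetch makes a single round trip for everything still missing.
        profiles = dict(cache_lookup(event_ids))
        missing = [e_id for e_id in event_ids if e_id not in profiles]
        if missing:
            profiles.update(batch_fetch(missing))
        return profiles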
+ """ + + rows = yield self.db_pool.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("user_id", "display_name", "avatar_url", "event_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=500, + desc="_get_membership_from_event_ids", + ) + + return { + row["event_id"]: ( + row["user_id"], + ProfileInfo( + avatar_url=row["avatar_url"], display_name=row["display_name"] + ), + ) + for row in rows + } + + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE m.membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self.db_pool.execute( + "is_host_joined", None, sql, room_id, like_clause + ) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self.db_pool.execute( + "was_host_joined", None, sql, room_id, like_clause + ) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @defer.inlineCallbacks + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_hosts"): + return ( + yield self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry + ) + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. 
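The `LIKE '%:host'` filters in `is_host_joined` and `was_host_joined` above are only a coarse pre-filter, which is why both re-derive the domain of the returned user ID and compare it to `host`. A sketch of that final check, assuming the `%`/`_` wildcard validation has already been done (the example user ID is hypothetical):

    def confirm_user_on_host(user_id, host):
        # The LIKE suffix match only guarantees the ID ends in ':<host>';
        # a hypothetical '@x:a:example.com' would match '%:example.com'
        # even though its server name is 'a:example.com'.
        domain = user_id.split(":", 1)[1]
        if domain != host:
            raise Exception("Invalid host name")
        return True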
+ assert state_group is not None + + cache = yield self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) + + return joined_hosts + + @cached(max_entries=10000) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + @cachedInlineCallbacks(num_args=2) + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + + count = yield self.db_pool.runInteraction("did_forget_membership", f) + return count == 0 + + @cached() + def get_forgotten_rooms_for_user(self, user_id): + """Gets all rooms the user has forgotten. + + Args: + user_id (str) + + Returns: + Deferred[set[str]] + """ + + def _get_forgotten_rooms_for_user_txn(txn): + # This is a slightly convoluted query that first looks up all rooms + # that the user has forgotten in the past, then rechecks that list + # to see if any have subsequently been updated. This is done so that + # we can use a partial index on `forgotten = 1` on the assumption + # that few users will actually forget many rooms. + # + # Note that a room is considered "forgotten" if *all* membership + # events for that user and room have the forgotten field set (as + # when a user forgets a room we update all rows for that user and + # room, not just the current one). + sql = """ + SELECT room_id, ( + SELECT count(*) FROM room_memberships + WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 + ) AS count + FROM room_memberships AS m + WHERE user_id = ? AND forgotten = 1 + GROUP BY room_id, user_id; + """ + txn.execute(sql, (user_id,)) + return {row[0] for row in txn if row[1] == 0} + + return self.db_pool.runInteraction( + "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn + ) + + @defer.inlineCallbacks + def get_rooms_user_has_been_in(self, user_id): + """Get all rooms that the user has ever been in. + + Args: + user_id (str) + + Returns: + Deferred[set[str]]: Set of room IDs. + """ + + room_ids = yield self.db_pool.simple_select_onecol( + table="room_memberships", + keyvalues={"membership": Membership.JOIN, "user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_has_been_in", + ) + + return set(room_ids) + + def get_membership_from_event_ids( + self, member_event_ids: Iterable[str] + ) -> List[dict]: + """Get user_id and membership of a set of event IDs. + """ + + return self.db_pool.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=("user_id", "membership", "event_id"), + keyvalues={}, + batch_size=500, + desc="get_membership_from_event_ids", + ) + + async def is_local_host_in_room_ignoring_users( + self, room_id: str, ignore_users: Collection[str] + ) -> bool: + """Check if there are any local users, excluding those in the given + list, in the room. + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", ignore_users + ) + + sql = """ + SELECT 1 FROM local_current_membership + WHERE + room_id = ? AND membership = ? 
+ AND NOT (%s) + LIMIT 1 + """ % ( + clause, + ) + + def _is_local_host_in_room_ignoring_users_txn(txn): + txn.execute(sql, (room_id, Membership.JOIN, *args)) + + return bool(txn.fetchone()) + + return await self.db_pool.runInteraction( + "is_local_host_in_room_ignoring_users", + _is_local_host_in_room_ignoring_users_txn, + ) + + +class RoomMemberBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db_pool.updates.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + self.db_pool.updates.register_background_update_handler( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + self._background_current_state_membership, + ) + self.db_pool.updates.register_background_index_update( + "room_membership_forgotten_idx", + index_name="room_memberships_user_room_forgotten", + table="room_memberships", + columns=["user_id", "room_id"], + where_clause="forgotten = 1", + ) + + @defer.inlineCallbacks + def _background_add_membership_profile(self, progress, batch_size): + target_min_stream_id = progress.get( + "target_min_stream_id_inclusive", self._min_stream_order_on_start + ) + max_stream_id = progress.get( + "max_stream_id_exclusive", self._stream_order_on_start + 1 + ) + + INSERT_CLUMP_SIZE = 1000 + + def add_membership_profile_txn(txn): + sql = """ + SELECT stream_ordering, event_id, events.room_id, event_json.json + FROM events + INNER JOIN event_json USING (event_id) + INNER JOIN room_memberships USING (event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND type = 'm.room.member' + ORDER BY stream_ordering DESC + LIMIT ? + """ + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = self.db_pool.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + to_update = [] + for row in rows: + event_id = row["event_id"] + room_id = row["room_id"] + try: + event_json = db_to_json(row["json"]) + content = event_json["content"] + except Exception: + continue + + display_name = content.get("displayname", None) + avatar_url = content.get("avatar_url", None) + + if display_name or avatar_url: + to_update.append((display_name, avatar_url, event_id, room_id)) + + to_update_sql = """ + UPDATE room_memberships SET display_name = ?, avatar_url = ? + WHERE event_id = ? AND room_id = ? + """ + for index in range(0, len(to_update), INSERT_CLUMP_SIZE): + clump = to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(to_update_sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self.db_pool.updates._background_update_progress_txn( + txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.db_pool.runInteraction( + _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn + ) + + if not result: + yield self.db_pool.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) + + return result + + @defer.inlineCallbacks + def _background_current_state_membership(self, progress, batch_size): + """Update the new membership column on current_state_events. + + This works by iterating over all rooms in alphebetical order. 
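Both background updates in this class follow the same checkpointing shape: do a bounded amount of work per call, persist a progress marker, and report how much was done so the updater knows when to stop rescheduling. A toy version of the room-by-room loop used below, simplified to count rooms rather than updated rows (`process_room` is a stand-in):

    def backfill_in_batches(room_ids, last_processed_room, batch_size, process_room):
        # Walk rooms in sorted order, starting strictly after the checkpoint,
        # and stop once batch_size rooms have been handled so progress can be
        # written back between transactions.
        processed = 0
        for room_id in sorted(room_ids):
            if room_id <= last_processed_room:
                continue
            if processed >= batch_size:
                return processed, last_processed_room, False  # more to do
            process_room(room_id)
            last_processed_room = room_id
            processed += 1
        return processed, last_processed_room, True  # finished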
+ """ + + def _background_current_state_membership_txn(txn, last_processed_room): + processed = 0 + while processed < batch_size: + txn.execute( + """ + SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? + """, + (last_processed_room,), + ) + row = txn.fetchone() + if not row or not row[0]: + return processed, True + + (next_room,) = row + + sql = """ + UPDATE current_state_events + SET membership = ( + SELECT membership FROM room_memberships + WHERE event_id = current_state_events.event_id + ) + WHERE room_id = ? + """ + txn.execute(sql, (next_room,)) + processed += txn.rowcount + + last_processed_room = next_room + + self.db_pool.updates._background_update_progress_txn( + txn, + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + {"last_processed_room": last_processed_room}, + ) + + return processed, False + + # If we haven't got a last processed room then just use the empty + # string, which will compare before all room IDs correctly. + last_processed_room = progress.get("last_processed_room", "") + + row_count, finished = yield self.db_pool.runInteraction( + "_background_current_state_membership_update", + _background_current_state_membership_txn, + last_processed_room, + ) + + if finished: + yield self.db_pool.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) + + return row_count + + +class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.get_forgotten_rooms_for_user, (user_id,) + ) + + return self.db_pool.runInteraction("forget_membership", f) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. 
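The interesting case in `get_destinations` below is the delta path: when the new state group has the cached one as its `prev_group`, only the changed membership events are applied to the per-host map instead of rebuilding it from the full room state. A sketch of that incremental step over plain dicts and sets (names are illustrative):

    def apply_membership_delta(hosts_to_joined_users, delta_ids, get_membership):
        # delta_ids maps (event_type, state_key) -> event_id for the state
        # that changed between the two groups; only m.room.member matters.
        for (etype, user_id), event_id in delta_ids.items():
            if etype != "m.room.member":
                continue
            host = user_id.split(":", 1)[1]
            joined = hosts_to_joined_users.setdefault(host, set())
            if get_membership(event_id) == "join":
                joined.add(user_id)
            else:
                joined.discard(user_id)
                if not joined:
                    hosts_to_joined_users.pop(host, None)
        return hosts_to_joined_users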
+ """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + return frozenset(self.hosts_to_joined_users) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.items(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) + return frozenset(self.hosts_to_joined_users) + + def __len__(self): + return self._len diff --git a/synapse/storage/databases/main/schema/delta/12/v12.sql b/synapse/storage/databases/main/schema/delta/12/v12.sql new file mode 100644 index 0000000000..5964c5aaac --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/12/v12.sql @@ -0,0 +1,63 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + UNIQUE (event_id) +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey VARBINARY(512) NOT NULL, + ts BIGINT UNSIGNED NOT NULL, + lang VARCHAR(8), + data LONGBLOB, + last_token TEXT, + last_success BIGINT UNSIGNED, + failing_since BIGINT UNSIGNED, + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); + +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id BIGINT UNSIGNED, + filter_json LONGBLOB +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/synapse/storage/databases/main/schema/delta/13/v13.sql b/synapse/storage/databases/main/schema/delta/13/v13.sql new file mode 100644 index 0000000000..f8649e5d99 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/13/v13.sql @@ -0,0 +1,19 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create a tables called application_services and + * application_services_regex, but these are no longer used and are removed in + * delta 54. + */ diff --git a/synapse/storage/databases/main/schema/delta/14/v14.sql b/synapse/storage/databases/main/schema/delta/14/v14.sql new file mode 100644 index 0000000000..a831920da6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/14/v14.sql @@ -0,0 +1,23 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled TINYINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql new file mode 100644 index 0000000000..e4f5e76aec --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql @@ -0,0 +1,31 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS application_services_state( + as_id TEXT PRIMARY KEY, + state VARCHAR(5), + last_txn INTEGER +); + +CREATE TABLE IF NOT EXISTS application_services_txns( + as_id TEXT NOT NULL, + txn_id INTEGER NOT NULL, + event_ids TEXT NOT NULL, + UNIQUE(as_id, txn_id) +); + +CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns ( + as_id +); diff --git a/synapse/storage/databases/main/schema/delta/15/presence_indices.sql b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql new file mode 100644 index 0000000000..6b8d0f1ca7 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql @@ -0,0 +1,2 @@ + +CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id); diff --git a/synapse/storage/databases/main/schema/delta/15/v15.sql b/synapse/storage/databases/main/schema/delta/15/v15.sql new file mode 100644 index 0000000000..9523d2bcc3 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/15/v15.sql @@ -0,0 +1,24 @@ +-- Drop, copy & recreate pushers table to change unique key +-- Also add access_token column at the same time +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey bytea NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data bytea, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey) +); +INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) + SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; +DROP TABLE pushers; +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/databases/main/schema/delta/16/events_order_index.sql b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql new file mode 100644 index 0000000000..a48f215170 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql @@ -0,0 +1,4 @@ 
+CREATE INDEX events_order ON events (topological_ordering, stream_ordering); +CREATE INDEX events_order_room ON events ( + room_id, topological_ordering, stream_ordering +); diff --git a/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql new file mode 100644 index 0000000000..7a15265cb1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id + ON remote_media_cache_thumbnails (media_id); \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql new file mode 100644 index 0000000000..65c97b5e2f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql @@ -0,0 +1,9 @@ + + +DELETE FROM event_to_state_groups WHERE state_group not in ( + SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id +); + +DELETE FROM event_to_state_groups WHERE rowid not in ( + SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id +); diff --git a/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql new file mode 100644 index 0000000000..f82486132b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql @@ -0,0 +1,3 @@ + +CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); +CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql new file mode 100644 index 0000000000..5b8de52c33 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql @@ -0,0 +1,72 @@ + +-- We can use SQLite features here, since other db support was only added in v16 + +-- +DELETE FROM current_state_events WHERE rowid not in ( + SELECT MIN(rowid) FROM current_state_events GROUP BY event_id +); + +DROP INDEX IF EXISTS current_state_events_event_id; +CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id); + +-- +DELETE FROM room_memberships WHERE rowid not in ( + SELECT MIN(rowid) FROM room_memberships GROUP BY event_id +); + +DROP INDEX IF EXISTS room_memberships_event_id; +CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id); + +-- +DELETE FROM topics WHERE rowid not in ( + SELECT MIN(rowid) FROM topics GROUP BY event_id +); + +DROP INDEX IF EXISTS topics_event_id; +CREATE UNIQUE INDEX topics_event_id ON topics(event_id); + +-- +DELETE FROM room_names WHERE rowid not in ( + SELECT MIN(rowid) FROM room_names GROUP BY event_id +); + +DROP INDEX IF EXISTS room_names_id; +CREATE UNIQUE INDEX room_names_id ON room_names(event_id); + +-- +DELETE FROM presence WHERE rowid not in ( + SELECT MIN(rowid) FROM presence GROUP BY user_id +); + +DROP INDEX IF EXISTS presence_id; +CREATE UNIQUE INDEX presence_id ON presence(user_id); + +-- +DELETE FROM presence_allow_inbound WHERE rowid not in ( + SELECT MIN(rowid) FROM presence_allow_inbound + GROUP BY observed_user_id, observer_user_id +); + +DROP INDEX IF EXISTS presence_allow_inbound_observers; +CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound( + observed_user_id, observer_user_id +); + +-- 
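Each block in this delta applies the same recipe: delete every duplicate row except the one with the lowest rowid for its key, then replace the old index with a UNIQUE one so duplicates cannot reappear. A self-contained demonstration of the pattern using Python's sqlite3 module and a made-up table:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE topics_demo (event_id TEXT, topic TEXT);
        INSERT INTO topics_demo VALUES ('$e1', 'a'), ('$e1', 'a'), ('$e2', 'b');

        -- Keep only the lowest-rowid row per event_id, then enforce uniqueness.
        DELETE FROM topics_demo WHERE rowid NOT IN (
            SELECT MIN(rowid) FROM topics_demo GROUP BY event_id
        );
        CREATE UNIQUE INDEX topics_demo_event_id ON topics_demo(event_id);
        """
    )
    print(conn.execute("SELECT COUNT(*) FROM topics_demo").fetchone()[0])  # prints 2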
+DELETE FROM presence_list WHERE rowid not in ( + SELECT MIN(rowid) FROM presence_list + GROUP BY user_id, observed_user_id +); + +DROP INDEX IF EXISTS presence_list_observers; +CREATE UNIQUE INDEX presence_list_observers ON presence_list( + user_id, observed_user_id +); + +-- +DELETE FROM room_aliases WHERE rowid not in ( + SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias +); + +DROP INDEX IF EXISTS room_aliases_id; +CREATE INDEX room_aliases_id ON room_aliases(room_id); diff --git a/synapse/storage/databases/main/schema/delta/16/users.sql b/synapse/storage/databases/main/schema/delta/16/users.sql new file mode 100644 index 0000000000..cd0709250d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/16/users.sql @@ -0,0 +1,56 @@ +-- Convert `access_tokens`.user from rowids to user strings. +-- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW +CREATE TABLE IF NOT EXISTS new_access_tokens( + id BIGINT UNSIGNED PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used BIGINT UNSIGNED, + UNIQUE(token) +); + +INSERT INTO new_access_tokens + SELECT a.id, u.name, a.device_id, a.token, a.last_used + FROM access_tokens as a + INNER JOIN users as u ON u.id = a.user_id; + +DROP TABLE access_tokens; + +ALTER TABLE new_access_tokens RENAME TO access_tokens; + +-- Remove ID column from `users` table +CREATE TABLE IF NOT EXISTS new_users( + name TEXT, + password_hash TEXT, + creation_ts BIGINT UNSIGNED, + admin BOOL DEFAULT 0 NOT NULL, + UNIQUE(name) +); + +INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users; + +DROP TABLE users; + +ALTER TABLE new_users RENAME TO users; + + +-- Remove UNIQUE constraint from `user_ips` table +CREATE TABLE IF NOT EXISTS new_user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT UNSIGNED NOT NULL +); + +INSERT INTO new_user_ips + SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips; + +DROP TABLE user_ips; + +ALTER TABLE new_user_ips RENAME TO user_ips; + +CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id); +CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip); + diff --git a/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql new file mode 100644 index 0000000000..7c9a90e27f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql @@ -0,0 +1,18 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +DROP INDEX IF EXISTS sent_transaction_dest; +DROP INDEX IF EXISTS sent_transaction_sent; +DROP INDEX IF EXISTS user_ips_user; diff --git a/synapse/storage/databases/main/schema/delta/17/server_keys.sql b/synapse/storage/databases/main/schema/delta/17/server_keys.sql new file mode 100644 index 0000000000..70b247a06b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/17/server_keys.sql @@ -0,0 +1,24 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS server_keys_json ( + server_name TEXT, -- Server name. + key_id TEXT, -- Requested key id. + from_server TEXT, -- Which server the keys were fetched from. + ts_added_ms INTEGER, -- When the keys were fetched + ts_valid_until_ms INTEGER, -- When this version of the keys exipires. + key_json bytea, -- JSON certificate for the remote server. + CONSTRAINT uniqueness UNIQUE (server_name, key_id, from_server) +); diff --git a/synapse/storage/databases/main/schema/delta/17/user_threepids.sql b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql new file mode 100644 index 0000000000..c17715ac80 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql @@ -0,0 +1,9 @@ +CREATE TABLE user_threepids ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + validated_at BIGINT NOT NULL, + added_at BIGINT NOT NULL, + CONSTRAINT user_medium_address UNIQUE (user_id, medium, address) +); +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql new file mode 100644 index 0000000000..6e0871c92b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql @@ -0,0 +1,32 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS new_server_keys_json ( + server_name TEXT NOT NULL, -- Server name. + key_id TEXT NOT NULL, -- Requested key id. + from_server TEXT NOT NULL, -- Which server the keys were fetched from. + ts_added_ms BIGINT NOT NULL, -- When the keys were fetched + ts_valid_until_ms BIGINT NOT NULL, -- When this version of the keys exipires. + key_json bytea NOT NULL, -- JSON certificate for the remote server. 
+ CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) +); + +INSERT INTO new_server_keys_json + SELECT server_name, key_id, from_server,ts_added_ms, ts_valid_until_ms, key_json FROM server_keys_json ; + +DROP TABLE server_keys_json; + +ALTER TABLE new_server_keys_json RENAME TO server_keys_json; diff --git a/synapse/storage/databases/main/schema/delta/19/event_index.sql b/synapse/storage/databases/main/schema/delta/19/event_index.sql new file mode 100644 index 0000000000..18b97b4332 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/19/event_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE INDEX events_order_topo_stream_room ON events( + topological_ordering, stream_ordering, room_id +); diff --git a/synapse/storage/databases/main/schema/delta/20/dummy.sql b/synapse/storage/databases/main/schema/delta/20/dummy.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/20/dummy.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py new file mode 100644 index 0000000000..3edfcfd783 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py @@ -0,0 +1,88 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Main purpose of this upgrade is to change the unique key on the +pushers table again (it was missed when the v16 full schema was +made) but this also changes the pushkey and data columns to text. +When selecting a bytea column into a text column, postgres inserts +the hex encoded data, and there's no portable way of getting the +UTF-8 bytes, so we have to do it in Python. 
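The decoding this refers to happens near the end of run_create below: the bytea-backed pushkey and data values are read back row by row and converted to text in Python. A minimal sketch of that conversion, with invented values standing in for what the driver returns (typically a memoryview or bytes object):

    raw_pushkey = memoryview(b"com.example.pushkey")  # invented value
    raw_data = memoryview(b'{"url": "https://push.example.invalid/notify"}')  # invented value

    pushkey = bytes(raw_pushkey).decode("utf-8")
    data = bytes(raw_data).decode("utf-8")
    assert isinstance(pushkey, str) and isinstance(data, str)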
+""" + +import logging + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table...") + cur.execute( + """ + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """ + ) + cur.execute( + """SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """ + ) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[8] = bytes(row[8]).decode("utf-8") + row[11] = bytes(row[11]).decode("utf-8") + cur.execute( + database_engine.convert_param_style( + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql new file mode 100644 index 0000000000..4c2fb20b77 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql @@ -0,0 +1,34 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( + user_id TEXT NOT NULL, -- The user these keys are for. + device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. + ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. + key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. + CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) +); + + +CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( + user_id TEXT NOT NULL, -- The user this one-time key is for. + device_id TEXT NOT NULL, -- The device this one-time key is for. + algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. + key_json TEXT NOT NULL, -- The key as a JSON blob. 
+ CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) +); diff --git a/synapse/storage/databases/main/schema/delta/21/receipts.sql b/synapse/storage/databases/main/schema/delta/21/receipts.sql new file mode 100644 index 0000000000..d070845477 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/21/receipts.sql @@ -0,0 +1,38 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS receipts_graph( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE TABLE IF NOT EXISTS receipts_linearized ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE INDEX receipts_linearized_id ON receipts_linearized( + stream_id +); diff --git a/synapse/storage/databases/main/schema/delta/22/receipts_index.sql b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql new file mode 100644 index 0000000000..bfc0b3bcaa --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( + room_id, stream_id +); diff --git a/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql new file mode 100644 index 0000000000..87edfa454c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS user_threepids2 ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + validated_at BIGINT NOT NULL, + added_at BIGINT NOT NULL, + CONSTRAINT medium_address UNIQUE (medium, address) +); + +INSERT INTO user_threepids2 + SELECT * FROM user_threepids WHERE added_at IN ( + SELECT max(added_at) FROM user_threepids GROUP BY medium, address + ) +; + +DROP TABLE user_threepids; +ALTER TABLE user_threepids2 RENAME TO user_threepids; + +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql new file mode 100644 index 0000000000..acea7483bd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* We used to create a table called stats_reporting, but this is no longer + * used and is removed in delta 54. + */ \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py new file mode 100644 index 0000000000..ee675e71ff --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -0,0 +1,80 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
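The delta below only creates the engine-specific `event_search` table; populating it is deferred by queueing an `event_search` background update whose progress JSON records the stream-ordering window still to be indexed. A sketch of that seeding step (the helper name is invented, and the parameter style is hard-coded to `?` where the real delta converts it per database engine):

    import json

    def seed_event_search_backfill(cur, min_stream_id, max_stream_id):
        # Record the window of events that still needs indexing; the
        # background updater later works through it in batches.
        progress = {
            "target_min_stream_id_inclusive": min_stream_id,
            "max_stream_id_exclusive": max_stream_id + 1,
            "rows_inserted": 0,
        }
        cur.execute(
            "INSERT INTO background_updates (update_name, progress_json)"
            " VALUES (?, ?)",
            ("event_search", json.dumps(progress)),
        )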
+import json +import logging + +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +POSTGRES_TABLE = """ +CREATE TABLE IF NOT EXISTS event_search ( + event_id TEXT, + room_id TEXT, + sender TEXT, + key TEXT, + vector tsvector +); + +CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); +CREATE INDEX event_search_ev_idx ON event_search(event_id); +CREATE INDEX event_search_ev_ridx ON event_search(room_id); +""" + + +SQLITE_TABLE = ( + "CREATE VIRTUAL TABLE event_search" + " USING fts4 ( event_id, room_id, sender, key, value )" +) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + cur.execute(SQLITE_TABLE) + else: + raise Exception("Unrecognized database engine") + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = json.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/25/guest_access.sql b/synapse/storage/databases/main/schema/delta/25/guest_access.sql new file mode 100644 index 0000000000..1ea389b471 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/25/guest_access.sql @@ -0,0 +1,25 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of guest_access content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS guest_access( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + guest_access TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/databases/main/schema/delta/25/history_visibility.sql b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql new file mode 100644 index 0000000000..f468fc1897 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql @@ -0,0 +1,25 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of history_visibility content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS history_visibility( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + history_visibility TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/databases/main/schema/delta/25/tags.sql b/synapse/storage/databases/main/schema/delta/25/tags.sql new file mode 100644 index 0000000000..7a32ce68e4 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/25/tags.sql @@ -0,0 +1,38 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS room_tags( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + tag TEXT NOT NULL, -- The name of the tag. + content TEXT NOT NULL, -- The JSON content of the tag. + CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) +); + +CREATE TABLE IF NOT EXISTS room_tags_revisions ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- The current version of the room tags. + CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) +); + +CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/databases/main/schema/delta/26/account_data.sql b/synapse/storage/databases/main/schema/delta/26/account_data.sql new file mode 100644 index 0000000000..e395de2b5e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/26/account_data.sql @@ -0,0 +1,17 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; diff --git a/synapse/storage/databases/main/schema/delta/27/account_data.sql b/synapse/storage/databases/main/schema/delta/27/account_data.sql new file mode 100644 index 0000000000..bf0558b5b3 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/27/account_data.sql @@ -0,0 +1,36 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS account_data( + user_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) +); + + +CREATE TABLE IF NOT EXISTS room_account_data( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) +); + + +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql new file mode 100644 index 0000000000..e2094f37fe --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql @@ -0,0 +1,26 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Keeps track of what rooms users have left and don't want to be able to + * access again. + * + * If all users on this server have left a room, we can delete the room + * entirely. + * + * This column should always contain either 0 or 1. 
+ */ + + ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER DEFAULT 0; diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py new file mode 100644 index 0000000000..b7972cfa8e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -0,0 +1,59 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging + +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = ( + "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;" + "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);" +) + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = json.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql new file mode 100644 index 0000000000..4d519849df --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql @@ -0,0 +1,27 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
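The comment on the forgotten column above notes that a room whose local users have all forgotten it can be deleted entirely. A purely illustrative query in that spirit, not taken from the codebase, showing how the 0/1 column makes that check a simple aggregate:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    # Toy fixture with just the columns the query needs.
    cur.execute(
        "CREATE TABLE room_memberships (room_id TEXT, user_id TEXT, forgotten INTEGER DEFAULT 0)"
    )
    cur.executemany(
        "INSERT INTO room_memberships VALUES (?, ?, ?)",
        [("!a:hs", "@u1:hs", 1), ("!a:hs", "@u2:hs", 1), ("!b:hs", "@u1:hs", 0)],
    )

    # Rooms where every stored membership row is marked forgotten are candidates
    # for full deletion, per the comment in the delta above.
    print(cur.execute(
        "SELECT room_id FROM room_memberships GROUP BY room_id HAVING MIN(forgotten) = 1"
    ).fetchall())  # -> [('!a:hs',)]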
+ */ + +CREATE TABLE IF NOT EXISTS event_push_actions( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) +); + + +CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag); +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); diff --git a/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql new file mode 100644 index 0000000000..36609475f1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql new file mode 100644 index 0000000000..6c1fd68c5b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql new file mode 100644 index 0000000000..cb84c69baa --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
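The event_push_actions table created above is keyed per room, event and user, and the second index exists to make per-room, per-user lookups cheap. A standalone sketch of the kind of read that index serves (the query and fixture rows are illustrative, not Synapse's own):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.executescript(
        """
        CREATE TABLE event_push_actions(
            room_id TEXT NOT NULL,
            event_id TEXT NOT NULL,
            user_id TEXT NOT NULL,
            profile_tag VARCHAR(32),
            actions TEXT NOT NULL,
            CONSTRAINT event_id_user_id_profile_tag_uniqueness
                UNIQUE (room_id, event_id, user_id, profile_tag)
        );
        CREATE INDEX event_push_actions_room_id_user_id
            ON event_push_actions(room_id, user_id);
        """
    )
    cur.executemany(
        "INSERT INTO event_push_actions VALUES (?, ?, ?, ?, ?)",
        [
            ("!room:hs", "$ev1", "@user:hs", None, '["notify"]'),
            ("!room:hs", "$ev2", "@user:hs", None, '["notify"]'),
        ],
    )

    # Count this user's stored push actions in one room; the (room_id, user_id)
    # index is what keeps this lookup cheap.
    print(cur.execute(
        "SELECT COUNT(*) FROM event_push_actions WHERE room_id = ? AND user_id = ?",
        ("!room:hs", "@user:hs"),
    ).fetchone()[0])  # -> 2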
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX receipts_linearized_user ON receipts_linearized( + user_id +); diff --git a/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql new file mode 100644 index 0000000000..3e4a9ab455 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Stores the timestamp when a user upgraded from a guest to a full user, if + * that happened. + */ + +ALTER TABLE users ADD COLUMN upgrade_ts BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql new file mode 100644 index 0000000000..21d2b420bf --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD is_guest SMALLINT DEFAULT 0 NOT NULL; +/* + * NB: any guest users created between 27 and 28 will be incorrectly + * marked as not guests: we don't bother to fill these in correctly + * because guest access is not really complete in 27 anyway so it's + * very unlikley there will be any guest users created. + */ diff --git a/synapse/storage/databases/main/schema/delta/29/push_actions.sql b/synapse/storage/databases/main/schema/delta/29/push_actions.sql new file mode 100644 index 0000000000..84b21cf813 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/29/push_actions.sql @@ -0,0 +1,35 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT; +ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT; +ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT; +ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT; + +UPDATE event_push_actions SET stream_ordering = ( + SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id +), topological_ordering = ( + SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id +); + +UPDATE event_push_actions SET notif = 1, highlight = 0; + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( + user_id, room_id, topological_ordering, stream_ordering +); diff --git a/synapse/storage/databases/main/schema/delta/30/alias_creator.sql b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql new file mode 100644 index 0000000000..c9d0dde638 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_aliases ADD COLUMN creator TEXT; diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py new file mode 100644 index 0000000000..b42c02710a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -0,0 +1,67 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.config.appservice import load_appservices + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + # NULL indicates user was not registered by an appservice. 
+ try: + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + except Exception: + # Maybe we already added the column? Hope so... + pass + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + cur.execute("SELECT name FROM users") + rows = cur.fetchall() + + config_files = [] + try: + config_files = config.app_service_config_files + except AttributeError: + logger.warning("Could not get app_service_config_files from config") + pass + + appservices = load_appservices(config.server_name, config_files) + + owned = {} + + for row in rows: + user_id = row[0] + for appservice in appservices: + if appservice.is_exclusive_user(user_id): + if user_id in owned.keys(): + logger.error( + "user_id %s was owned by more than one application" + " service (IDs %s and %s); assigning arbitrarily to %s" + % (user_id, owned[user_id], appservice.id, owned[user_id]) + ) + owned.setdefault(appservice.id, []).append(user_id) + + for as_id, user_ids in owned.items(): + n = 100 + user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) + for chunk in user_chunks: + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),) + ), + [as_id] + chunk, + ) diff --git a/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql new file mode 100644 index 0000000000..712c454aa1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql @@ -0,0 +1,25 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS deleted_pushers( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL, + /* We only track the most recent delete for each app_id, pushkey and user_id. */ + UNIQUE (app_id, pushkey, user_id) +); + +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); diff --git a/synapse/storage/databases/main/schema/delta/30/presence_stream.sql b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..606bbb037d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
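run_upgrade() above assigns users to appservices in chunks of 100 so that each UPDATE binds a bounded number of parameters (SQLite in particular caps bound parameters per statement; older builds default to 999). A minimal standalone sketch of the same chunking idiom, with an invented table, appservice ID and user list:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE users (name TEXT, appservice_id TEXT)")
    cur.executemany("INSERT INTO users (name) VALUES (?)", [("@u%d:hs" % i,) for i in range(250)])

    as_id = "irc-bridge"  # illustrative appservice ID
    user_ids = ["@u%d:hs" % i for i in range(250)]

    n = 100
    for i in range(0, len(user_ids), n):
        chunk = user_ids[i : i + n]
        # One placeholder per user in this chunk, exactly as the delta does.
        cur.execute(
            "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
            % (",".join("?" for _ in chunk)),
            [as_id] + chunk,
        )

    print(cur.execute(
        "SELECT COUNT(*) FROM users WHERE appservice_id = ?", (as_id,)
    ).fetchone()[0])  # -> 250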
+ */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active_ts BIGINT, + last_federation_update_ts BIGINT, + last_user_sync_ts BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/databases/main/schema/delta/30/public_rooms.sql b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql new file mode 100644 index 0000000000..f09db4faa6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* This release removes the restriction that published rooms must have an alias, + * so we go back and ensure the only 'public' rooms are ones with an alias. + * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite + */ +UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in ( + SELECT room_id FROM room_aliases +); diff --git a/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql new file mode 100644 index 0000000000..735aa8d5f6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +CREATE TABLE push_rules_stream( + stream_id BIGINT NOT NULL, + event_stream_ordering BIGINT NOT NULL, + user_id TEXT NOT NULL, + rule_id TEXT NOT NULL, + op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" + priority_class SMALLINT, + priority INTEGER, + conditions TEXT, + actions TEXT +); + +-- The extra data for each operation is: +-- * ENABLE, DISABLE, DELETE: [] +-- * ACTIONS: ["actions"] +-- * ADD: ["priority_class", "priority", "actions", "conditions"] + +-- Index for replication queries. +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +-- Index for /sync queries. 
+CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); diff --git a/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql new file mode 100644 index 0000000000..0dd2f1360c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores guest account access tokens generated for unbound 3pids. +CREATE TABLE threepid_guest_access_tokens( + medium TEXT, -- The medium of the 3pid. Must be "email". + address TEXT, -- The 3pid address. + guest_access_token TEXT, -- The access token for a guest user for this 3pid. + first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. +); + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/databases/main/schema/delta/31/invites.sql b/synapse/storage/databases/main/schema/delta/31/invites.sql new file mode 100644 index 0000000000..2c57846d5a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/31/invites.sql @@ -0,0 +1,42 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +CREATE TABLE local_invites( + stream_id BIGINT NOT NULL, + inviter TEXT NOT NULL, + invitee TEXT NOT NULL, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + locally_rejected TEXT, + replaced_by TEXT +); + +-- Insert all invites for local users into new `invites` table +INSERT INTO local_invites SELECT + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by + FROM events + NATURAL JOIN current_state_events + NATURAL JOIN room_memberships + WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); diff --git a/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql new file mode 100644 index 0000000000..9efb4280eb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql @@ -0,0 +1,27 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE local_media_repository_url_cache( + url TEXT, -- the URL being cached + response_code INTEGER, -- the HTTP response code of this download attempt + etag TEXT, -- the etag header of this response + expires INTEGER, -- the number of ms this response was valid for + og TEXT, -- cache of the OG metadata of this URL as JSON + media_id TEXT, -- the media_id, if any, of the URL's content in the repo + download_ts BIGINT -- the timestamp of this download attempt +); + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts + ON local_media_repository_url_cache(url, download_ts); diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py new file mode 100644 index 0000000000..9bb504aad5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py @@ -0,0 +1,87 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Change the last_token to last_stream_ordering now that pushers no longer +# listen on an event stream but instead select out of the event_push_actions +# table. 
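The comment opening 31/pushers.py above says pushers now "select out of the event_push_actions table" rather than following an event stream. An illustrative sketch of that kind of read, using the stream_ordering column added in delta 29 (the query, fixture and last_stream_ordering value are assumptions for illustration, not the server's actual code):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    # Only the columns this illustration needs.
    cur.execute(
        "CREATE TABLE event_push_actions ("
        " user_id TEXT, room_id TEXT, event_id TEXT, stream_ordering BIGINT, actions TEXT)"
    )
    cur.executemany(
        "INSERT INTO event_push_actions VALUES (?, ?, ?, ?, ?)",
        [
            ("@user:hs", "!room:hs", "$ev1", 10, '["notify"]'),
            ("@user:hs", "!room:hs", "$ev2", 20, '["notify"]'),
        ],
    )

    last_stream_ordering = 10  # where this pusher got to last time (illustrative)
    rows = cur.execute(
        "SELECT event_id, stream_ordering, actions FROM event_push_actions"
        " WHERE user_id = ? AND stream_ordering > ?"
        " ORDER BY stream_ordering ASC",
        ("@user:hs", last_stream_ordering),
    ).fetchall()
    print(rows)  # -> only $ev2, the action not yet delivered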
+ + +import logging + +logger = logging.getLogger(__name__) + + +def token_to_stream_ordering(token): + return int(token[1:].split("_")[0]) + + +def run_create(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table, delta 31...") + cur.execute( + """ + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """ + ) + cur.execute( + """SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """ + ) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[12] = token_to_stream_ordering(row[12]) + cur.execute( + database_engine.convert_param_style( + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/31/pushers_index.sql b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql new file mode 100644 index 0000000000..a82add88fd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ + CREATE INDEX event_push_actions_stream_ordering on event_push_actions( + stream_ordering, user_id + ); diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py new file mode 100644 index 0000000000..63b757ade6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -0,0 +1,64 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
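token_to_stream_ordering() above simply strips the leading type character from an old-style last_token and keeps the first underscore-separated field. A quick usage sketch; the token values are made up for illustration:

    def token_to_stream_ordering(token):
        # Same one-liner as the delta above.
        return int(token[1:].split("_")[0])

    print(token_to_stream_ordering("s1234_56_7"))  # -> 1234
    print(token_to_stream_ordering("s99"))         # -> 99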
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = """ +ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT; +ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT; +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + "have_added_indexes": False, + } + progress_json = json.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_search_order", progress_json)) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/32/events.sql b/synapse/storage/databases/main/schema/delta/32/events.sql new file mode 100644 index 0000000000..1dd0f9e170 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/32/events.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +ALTER TABLE events ADD COLUMN received_ts BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/32/openid.sql b/synapse/storage/databases/main/schema/delta/32/openid.sql new file mode 100644 index 0000000000..36f37b11c8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/32/openid.sql @@ -0,0 +1,9 @@ + +CREATE TABLE open_id_tokens ( + token TEXT NOT NULL PRIMARY KEY, + ts_valid_until_ms bigint NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (token) +); + +CREATE index open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); diff --git a/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql new file mode 100644 index 0000000000..d86d30c13c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE pusher_throttle( + pusher BIGINT NOT NULL, + room_id TEXT NOT NULL, + last_sent_ts BIGINT, + throttle_ms BIGINT, + PRIMARY KEY (pusher, room_id) +); diff --git a/synapse/storage/databases/main/schema/delta/32/remove_indices.sql b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql new file mode 100644 index 0000000000..2de50d408c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream +DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room +DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room +DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT + +DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT + +-- The following indices were unused +DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id; +DROP INDEX IF EXISTS evauth_edges_auth_id; +DROP INDEX IF EXISTS presence_stream_state; diff --git a/synapse/storage/databases/main/schema/delta/32/reports.sql b/synapse/storage/databases/main/schema/delta/32/reports.sql new file mode 100644 index 0000000000..d13609776f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/32/reports.sql @@ -0,0 +1,25 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE event_reports( + id BIGINT NOT NULL PRIMARY KEY, + received_ts BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + reason TEXT, + content TEXT +); diff --git a/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
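The "Prefix of ..." comments in 32/remove_indices.sql above rely on the fact that a composite B-tree index on (a, b) can also serve lookups on a alone, so a separate index on just a is redundant. A standalone SQLite sketch of that claim (the table and index names here are illustrative, not the real schema):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE events_demo (room_id TEXT, stream_ordering BIGINT)")
    # Composite index: its leading column already covers room_id-only lookups.
    cur.execute("CREATE INDEX events_demo_room_stream ON events_demo(room_id, stream_ordering)")

    # The planner can use the composite index for a room_id-only predicate, which
    # is why the single-column indexes dropped above were considered redundant.
    for row in cur.execute(
        "EXPLAIN QUERY PLAN SELECT * FROM events_demo WHERE room_id = ?", ("!room:hs",)
    ):
        print(row)  # the plan should reference events_demo_room_stream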
+ */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/33/devices.sql b/synapse/storage/databases/main/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql new file mode 100644 index 0000000000..aa4a3b9f2f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- make sure that we have a device record for each set of E2E keys, so that the +-- user can delete them if they like. +INSERT INTO devices + SELECT user_id, device_id, NULL FROM e2e_device_keys_json; diff --git a/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql new file mode 100644 index 0000000000..6671573398 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a previous version of the "devices_for_e2e_keys" delta set all the device +-- names to "unknown device". 
This wasn't terribly helpful +UPDATE devices + SET display_name = NULL + WHERE display_name = 'unknown device'; diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py new file mode 100644 index 0000000000..a3e81eeac7 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging + +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = """ +ALTER TABLE events ADD COLUMN sender TEXT; +ALTER TABLE events ADD COLUMN contains_url BOOLEAN; +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = json.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_fields_sender_url", progress_json)) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py new file mode 100644 index 0000000000..a26057dfb6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -0,0 +1,30 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute(ALTER_TABLE) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + cur.execute( + database_engine.convert_param_style( + "UPDATE remote_media_cache SET last_access_ts = ?" 
+ ), + (int(time.time() * 1000),), + ) diff --git a/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql new file mode 100644 index 0000000000..473f75a78e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_device_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql new file mode 100644 index 0000000000..69e16eda0f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS appservice_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT, + CHECK (Lock='X') +); + +INSERT INTO appservice_stream_position (stream_ordering) + SELECT COALESCE(MAX(stream_ordering), 0) FROM events; diff --git a/synapse/storage/databases/main/schema/delta/34/cache_stream.py b/synapse/storage/databases/main/schema/delta/34/cache_stream.py new file mode 100644 index 0000000000..cf09e43e2b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/34/cache_stream.py @@ -0,0 +1,46 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +# This stream is used to notify replication slaves that some caches have +# been invalidated that they cannot infer from the other streams. 
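appservice_stream_position above (like the earlier *_max_stream_id tables in this series) uses the single-row "Lock" trick: a CHAR(1) column that defaults to 'X', is UNIQUE, and is CHECKed to equal 'X', so the table can never hold more than one row. A standalone sketch of that behaviour:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.executescript(
        """
        CREATE TABLE IF NOT EXISTS appservice_stream_position(
            Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
            stream_ordering BIGINT,
            CHECK (Lock='X')
        );
        INSERT INTO appservice_stream_position (stream_ordering) VALUES (0);
        """
    )

    try:
        # A second row also defaults Lock to 'X', which violates the UNIQUE constraint.
        cur.execute("INSERT INTO appservice_stream_position (stream_ordering) VALUES (1)")
    except sqlite3.IntegrityError as e:
        print("second row rejected:", e)

    # The only way to move the position forward is to UPDATE the single row.
    cur.execute("UPDATE appservice_stream_position SET stream_ordering = ?", (42,))
    print(cur.execute("SELECT stream_ordering FROM appservice_stream_position").fetchall())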
+CREATE_TABLE = """ +CREATE TABLE cache_invalidation_stream ( + stream_id BIGINT, + cache_func TEXT, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/34/device_inbox.sql b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql new file mode 100644 index 0000000000..e68844c74a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_inbox ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, + message_json TEXT NOT NULL -- {"type":, "sender":, "content",} +); + +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql new file mode 100644 index 0000000000..0d9fe1a99a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
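The device_inbox table above stores per-device messages keyed by a stream id, and its (user_id, device_id, stream_id) index is what lets a reader pull only messages newer than the last stream position it acknowledged. An illustrative sketch of that read pattern (the query, message bodies and stream position are assumptions for illustration, not Synapse's own code):

    import json
    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.executescript(
        """
        CREATE TABLE device_inbox (
            user_id TEXT NOT NULL,
            device_id TEXT NOT NULL,
            stream_id BIGINT NOT NULL,
            message_json TEXT NOT NULL
        );
        CREATE INDEX device_inbox_user_stream_id
            ON device_inbox(user_id, device_id, stream_id);
        """
    )
    cur.executemany(
        "INSERT INTO device_inbox VALUES (?, ?, ?, ?)",
        [
            ("@user:hs", "DEVICE1", 1, json.dumps({"type": "m.example", "content": {}})),
            ("@user:hs", "DEVICE1", 2, json.dumps({"type": "m.example", "content": {}})),
        ],
    )

    last_seen_stream_id = 1  # illustrative: position the device last acknowledged
    rows = cur.execute(
        "SELECT stream_id, message_json FROM device_inbox"
        " WHERE user_id = ? AND device_id = ? AND stream_id > ?"
        " ORDER BY stream_id",
        ("@user:hs", "DEVICE1", last_seen_stream_id),
    ).fetchall()
    print(rows)  # -> only the stream_id=2 message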
+ */ + +DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name'; +UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; + +DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name'; +UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; diff --git a/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py new file mode 100644 index 0000000000..67d505e68b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py @@ -0,0 +1,32 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute("TRUNCATE received_transactions") + else: + cur.execute("DELETE FROM received_transactions") + + cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/35/contains_url.sql b/synapse/storage/databases/main/schema/delta/35/contains_url.sql new file mode 100644 index 0000000000..6cd123027b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/35/contains_url.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/35/device_outbox.sql b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql new file mode 100644 index 0000000000..17e6c43105 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql @@ -0,0 +1,39 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
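received_txn_purge.py above branches on the database engine because TRUNCATE is a Postgres statement and typically faster there, while SQLite only has DELETE. The shape of that branch as an engine-agnostic standalone sketch (the is_postgres flag stands in for the isinstance(database_engine, PostgresEngine) check; the fixture row is invented):

    import sqlite3

    def purge_received_transactions(cur, is_postgres):
        # Mirrors the delta: TRUNCATE on Postgres, a plain DELETE everywhere else,
        # then add the timestamp index used for future purges.
        if is_postgres:
            cur.execute("TRUNCATE received_transactions")
        else:
            cur.execute("DELETE FROM received_transactions")
        cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE received_transactions (transaction_id TEXT, ts BIGINT)")
    cur.execute("INSERT INTO received_transactions VALUES ('txn1', 123)")

    purge_received_transactions(cur, is_postgres=False)
    print(cur.execute("SELECT COUNT(*) FROM received_transactions").fetchone()[0])  # -> 0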
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS device_federation_outbox; +CREATE TABLE device_federation_outbox ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + queued_ts BIGINT NOT NULL, + messages_json TEXT NOT NULL +); + + +DROP INDEX IF EXISTS device_federation_outbox_destination_id; +CREATE INDEX device_federation_outbox_destination_id + ON device_federation_outbox(destination, stream_id); + + +DROP TABLE IF EXISTS device_federation_inbox; +CREATE TABLE device_federation_inbox ( + origin TEXT NOT NULL, + message_id TEXT NOT NULL, + received_ts BIGINT NOT NULL +); + +DROP INDEX IF EXISTS device_federation_inbox_sender_id; +CREATE INDEX device_federation_inbox_sender_id + ON device_federation_inbox(origin, message_id); diff --git a/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql new file mode 100644 index 0000000000..7ab7d942e2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_max_stream_id ( + stream_id BIGINT NOT NULL +); + +INSERT INTO device_max_stream_id (stream_id) + SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox; diff --git a/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql new file mode 100644 index 0000000000..2e836d8e9c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql new file mode 100644 index 0000000000..dd2bf2e28a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE public_room_list_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + visibility BOOLEAN NOT NULL +); + +INSERT INTO public_room_list_stream (stream_id, room_id, visibility) + SELECT 1, room_id, is_public FROM rooms + WHERE is_public = CAST(1 AS BOOLEAN); + +CREATE INDEX public_room_list_stream_idx on public_room_list_stream( + stream_id +); + +CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( + room_id, stream_id +); diff --git a/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql new file mode 100644 index 0000000000..2b945d8a57 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql @@ -0,0 +1,37 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT stream_ordering, room_id, event_id FROM event_forward_extremities + INNER JOIN ( + SELECT room_id, max(stream_ordering) as stream_ordering FROM events + INNER JOIN event_forward_extremities USING (room_id, event_id) + GROUP BY room_id + ) AS rms USING (room_id); + +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( + stream_ordering +); + +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( + room_id, stream_ordering +); diff --git a/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql new file mode 100644 index 0000000000..90d8fd18f9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql @@ -0,0 +1,26 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT + (SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering, + room_id, + event_id + FROM event_forward_extremities AS e + WHERE NOT EXISTS ( + SELECT room_id FROM stream_ordering_to_exterm AS s + WHERE s.room_id = e.room_id + ); diff --git a/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py new file mode 100644 index 0000000000..a377884169 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py @@ -0,0 +1,85 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + +DROP_INDICES = """ +-- We only ever query based on event_id +DROP INDEX IF EXISTS state_events_room_id; +DROP INDEX IF EXISTS state_events_type; +DROP INDEX IF EXISTS state_events_state_key; + +-- room_id is indexed elsewhere +DROP INDEX IF EXISTS current_state_events_room_id; +DROP INDEX IF EXISTS current_state_events_state_key; +DROP INDEX IF EXISTS current_state_events_type; + +DROP INDEX IF EXISTS transactions_have_ref; + +-- (topological_ordering, stream_ordering, room_id) seems like a strange index, +-- and is used incredibly rarely. +DROP INDEX IF EXISTS events_order_topo_stream_room; + +-- an equivalent index to this actually gets re-created in delta 41, because it +-- turned out that deleting it wasn't a great plan :/. In any case, let's +-- delete it here, and delta 41 will create a new one with an added UNIQUE +-- constraint +DROP INDEX IF EXISTS event_search_ev_idx; +""" + +POSTGRES_DROP_CONSTRAINT = """ +ALTER TABLE event_auth DROP CONSTRAINT IF EXISTS event_auth_event_id_auth_id_room_id_key; +""" + +SQLITE_DROP_CONSTRAINT = """ +DROP INDEX IF EXISTS evauth_edges_id; + +CREATE TABLE IF NOT EXISTS event_auth_new( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +INSERT INTO event_auth_new + SELECT event_id, auth_id, room_id + FROM event_auth; + +DROP TABLE event_auth; + +ALTER TABLE event_auth_new RENAME TO event_auth; + +CREATE INDEX evauth_edges_id ON event_auth(event_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(DROP_INDICES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + drop_constraint = POSTGRES_DROP_CONSTRAINT + else: + drop_constraint = SQLITE_DROP_CONSTRAINT + + for statement in get_statements(drop_constraint.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/37/user_threepids.sql b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql new file mode 100644 index 0000000000..cf7a90dd10 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql @@ -0,0 +1,52 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Update any email addresses that were stored with mixed case into all + * lowercase + */ + + -- There may be "duplicate" emails (with different case) already in the table, + -- so we find them and move all but the most recently used account. + UPDATE user_threepids + SET medium = 'email_old' + WHERE medium = 'email' + AND address IN ( + -- We select all the addresses that are linked to the user_id that is NOT + -- the most recently created. 
+        SELECT u.address
+        FROM
+            user_threepids AS u,
+            -- `duplicate_addresses` is a table of all the email addresses that
+            -- appear multiple times and when the binding was created
+            (
+                SELECT lower(u1.address) AS address, max(u1.added_at) AS max_ts
+                FROM user_threepids AS u1
+                INNER JOIN user_threepids AS u2 ON u1.medium = u2.medium AND lower(u1.address) = lower(u2.address) AND u1.address != u2.address
+                WHERE u1.medium = 'email' AND u2.medium = 'email'
+                GROUP BY lower(u1.address)
+            ) AS duplicate_addresses
+        WHERE
+            lower(u.address) = duplicate_addresses.address
+            AND u.added_at != max_ts -- NOT the most recently created
+    );
+
+
+-- This update is now safe since we've removed the duplicate addresses.
+UPDATE user_threepids SET address = LOWER(address) WHERE medium = 'email';
+
+
+/* Add an index for the select we do on password reset */
+CREATE INDEX user_threepids_medium_address on user_threepids (medium, address);
diff --git a/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql
new file mode 100644
index 0000000000..515e6b8e84
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We no longer do this given we back it out again in schema 47
+
+-- INSERT into background_updates (update_name, progress_json)
+--  VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql
new file mode 100644
index 0000000000..74bdc49073
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql
@@ -0,0 +1,29 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE appservice_room_list(
+    appservice_id TEXT NOT NULL,
+    network_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+-- Each appservice can have multiple published room lists associated with them,
+-- keyed off a particular network_id
+CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
+    appservice_id, network_id, room_id
+);
+
+ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
+ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql
new file mode 100644
index 0000000000..00be801e90
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id);
diff --git a/synapse/storage/databases/main/schema/delta/39/event_push_index.sql b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql
new file mode 100644
index 0000000000..de2ad93e5c
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('event_push_actions_highlights_index', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql
new file mode 100644
index 0000000000..5af814290b
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql
@@ -0,0 +1,22 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ CREATE TABLE federation_stream_position(
+     type TEXT NOT NULL,
+     stream_id INTEGER NOT NULL
+ );
+
+ INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1);
+ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events;
diff --git a/synapse/storage/databases/main/schema/delta/39/membership_profile.sql b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql
new file mode 100644
index 0000000000..1bf911c8ab
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_memberships ADD COLUMN display_name TEXT;
+ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT;
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('room_membership_profile_update', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql
new file mode 100644
index 0000000000..7ffa189f39
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('current_state_members_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/40/device_inbox.sql b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql
new file mode 100644
index 0000000000..b9fe1f0480
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- turn the pre-fill startup query into an index-only scan on postgresql.
+INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..dd6dcb65f1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql @@ -0,0 +1,60 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Cache of remote devices. +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +-- The last update we got for a user. Empty if we're not receiving updates for +-- that user. +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +-- we used to create non-unique indexes on these tables, but as of update 52 we create +-- unique indexes concurrently: +-- +-- CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +-- CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + +-- Stream of device lists updates. Includes both local and remotes +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +-- The stream of updates to send to other servers. We keep at least one row +-- per user that was sent so that the prev_id for any new updates can be +-- calculated +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL, + ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql new file mode 100644 index 0000000000..3918f0b794 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql @@ -0,0 +1,37 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Aggregate of old notification counts that have been deleted out of the +-- main event_push_actions table. This count does not include those that were +-- highlights, as they remain in the event_push_actions table. +CREATE TABLE event_push_summary ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); + + +-- The stream ordering up to which we have aggregated the event_push_actions +-- table into event_push_summary +CREATE TABLE event_push_summary_stream_ordering ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/databases/main/schema/delta/40/pushers.sql b/synapse/storage/databases/main/schema/delta/40/pushers.sql new file mode 100644 index 0000000000..054a223f14 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/40/pushers.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag TEXT NOT NULL, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang TEXT, + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) +); + +INSERT INTO pushers2 SELECT * FROM PUSHERS; + +DROP TABLE PUSHERS; + +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql new file mode 100644 index 0000000000..b7bee8b692 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql new file mode 100644 index 0000000000..62f0b9892b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql new file mode 100644 index 0000000000..5d9cfecf36 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/41/ratelimit.sql b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql new file mode 100644 index 0000000000..a194bf0238 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql @@ -0,0 +1,22 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE ratelimit_override ( + user_id TEXT NOT NULL, + messages_per_second BIGINT, + burst_count BIGINT +); + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql new file mode 100644 index 0000000000..d28851aff8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql @@ -0,0 +1,26 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE current_state_delta_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT, -- Is null if the key was removed + prev_event_id TEXT -- Is null if the key was added +); + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql new file mode 100644 index 0000000000..9ab8c14fa3 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Table of last stream_id that we sent to destination for user_id. This is +-- used to fill out the `prev_id` fields of outbound device list updates. 
+CREATE TABLE device_lists_outbound_last_success ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +INSERT INTO device_lists_outbound_last_success + SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values + GROUP BY destination, user_id; + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( + destination, user_id, stream_id +); diff --git a/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql new file mode 100644 index 0000000000..b8821ac759 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_auth_state_only', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/42/user_dir.py b/synapse/storage/databases/main/schema/delta/42/user_dir.py new file mode 100644 index 0000000000..506f326f4d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/42/user_dir.py @@ -0,0 +1,84 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +BOTH_TABLES = """ +CREATE TABLE user_directory_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. 
+ stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_directory ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, -- A room_id that we know the user is joined to + display_name TEXT, + avatar_url TEXT +); + +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); + +CREATE TABLE users_in_pubic_room ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL -- A room_id that we know is public +); + +CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); +CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); +""" + + +POSTGRES_TABLE = """ +CREATE TABLE user_directory_search ( + user_id TEXT NOT NULL, + vector tsvector +); + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); +""" + + +SQLITE_TABLE = """ +CREATE VIRTUAL TABLE user_directory_search + USING fts4 ( user_id, value ); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(BOTH_TABLES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + for statement in get_statements(SQLITE_TABLE.splitlines()): + cur.execute(statement) + else: + raise Exception("Unrecognized database engine") + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql new file mode 100644 index 0000000000..0e3cd143ff --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql @@ -0,0 +1,21 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE blocked_rooms ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL -- Admin who blocked the room +); + +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); diff --git a/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql new file mode 100644 index 0000000000..630907ec4f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT; +ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT; diff --git a/synapse/storage/databases/main/schema/delta/43/url_cache.sql b/synapse/storage/databases/main/schema/delta/43/url_cache.sql new file mode 100644 index 0000000000..45ebe020da --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/43/url_cache.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT; diff --git a/synapse/storage/databases/main/schema/delta/43/user_share.sql b/synapse/storage/databases/main/schema/delta/43/user_share.sql new file mode 100644 index 0000000000..ee7062abe4 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/43/user_share.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Table keeping track of who shares a room with who. We only keep track +-- of this for local users, so `user_id` is local users only (but we do keep track +-- of which remote users share a room) +CREATE TABLE users_who_share_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room +); + + +CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id); +CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); +CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); + + +-- Make sure that we populate the table initially +UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..b12f9b2ebf --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,41 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/databases/main/schema/delta/45/group_server.sql b/synapse/storage/databases/main/schema/delta/45/group_server.sql new file mode 100644 index 0000000000..b2333848a0 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/45/group_server.sql @@ -0,0 +1,167 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+CREATE TABLE groups (
+    group_id TEXT NOT NULL,
+    name TEXT, -- the display name of the room
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the user's membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    room_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the room should be shown to everyone
+    UNIQUE (group_id, category_id, room_id, room_order),
+    CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    cat_order BIGINT NOT NULL,
+    UNIQUE (group_id, category_id, cat_order),
+    CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the category should be shown to everyone
+    UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    user_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the user should be shown to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    role_order BIGINT NOT NULL,
+    UNIQUE (group_id, role_id, role_order),
+    CHECK (role_order > 0)
+);
+
+
+-- The roles in a group
+CREATE TABLE group_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the role should be shown to everyone
+    UNIQUE (group_id, role_id)
+);
+
+
+-- List of attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL, + attestation_json TEXT NOT NULL +); + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); + + +-- The group membership for the HS's users +CREATE TABLE local_group_membership ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + membership TEXT NOT NULL, + is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership + content TEXT NOT NULL +); + +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); + + +CREATE TABLE local_group_updates ( + stream_id BIGINT NOT NULL, + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL +); diff --git a/synapse/storage/databases/main/schema/delta/45/profile_cache.sql b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql new file mode 100644 index 0000000000..e5ddc84df0 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql @@ -0,0 +1,28 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL +); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql new file mode 100644 index 0000000000..68c48a89a9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/databases/main/schema/delta/46/group_server.sql b/synapse/storage/databases/main/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..097679bc9a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql new file mode 100644 index 0000000000..bbfc7f5d1a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. 
+-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql new file mode 100644 index 0000000000..d9505f8da1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this is just embarassing :| +ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; + +-- this is only 300K rows on matrix.org and takes ~3s to generate the index, +-- so is hopefully not going to block anyone else for that long... +CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); +CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); +DROP INDEX users_in_pubic_room_room_idx; +DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/databases/main/schema/delta/47/last_access_media.sql b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql new file mode 100644 index 0000000000..f505fb22b5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql new file mode 100644 index 0000000000..31d7a817eb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
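The user_dir_typos delta above renames the mis-spelled table and then swaps its indexes, creating the correctly named ones before dropping the old ones (a table rename in SQLite does not rename the indexes attached to it). A toy replay of that sequence; the two-column table definition is only inferred from the index definitions in the delta and is for illustration.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # Stand-in for the mis-spelled table and its mis-spelled indexes;
    # the column list is an assumption based on the indexed columns.
    cur.execute("CREATE TABLE users_in_pubic_room (user_id TEXT NOT NULL, room_id TEXT NOT NULL)")
    cur.execute("CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id)")
    cur.execute("CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id)")

    # Rename, create the correctly named indexes, then drop the old ones.
    cur.execute("ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms")
    cur.execute("CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id)")
    cur.execute("CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id)")
    cur.execute("DROP INDEX users_in_pubic_room_room_idx")
    cur.execute("DROP INDEX users_in_pubic_room_user_idx")
    conn.commit()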
+ */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql new file mode 100644 index 0000000000..edccf4a96f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql @@ -0,0 +1,28 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Temporary staging area for push actions that have been calculated for an +-- event, but the event hasn't yet been persisted. +-- When the event is persisted the rows are moved over to the +-- event_push_actions table. +CREATE TABLE event_push_actions_staging ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL +); + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql new file mode 100644 index 0000000000..5237491506 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql @@ -0,0 +1,18 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* record the version of the privacy policy the user has consented to + */ +ALTER TABLE users ADD COLUMN consent_version TEXT; diff --git a/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql new file mode 100644 index 0000000000..9248b0b24a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
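Several of the deltas above (local_media_repository_url_idx, event_search_postgres_gin, user_ips_last_seen_index) do no work themselves; they only queue a named background update for the server to carry out later. A minimal sketch of that registration pattern against a cut-down stand-in for the background_updates table; the real machinery lives in Synapse's background updater and processes work incrementally, which this toy does not attempt to reproduce.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # Cut-down stand-in for the real background_updates table.
    cur.execute(
        "CREATE TABLE background_updates ("
        " update_name TEXT PRIMARY KEY, progress_json TEXT NOT NULL, depends_on TEXT)"
    )

    # This is all a delta file does: queue the work by name.
    cur.execute(
        "INSERT INTO background_updates (update_name, progress_json) VALUES (?, ?)",
        ("event_search_postgres_gin", "{}"),
    )

    # On startup the background updater looks for queued work and chips away
    # at it in batches, recording its position in progress_json as it goes.
    pending = cur.execute("SELECT update_name FROM background_updates").fetchall()
    print(pending)  # [('event_search_postgres_gin',)]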
+ */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql new file mode 100644 index 0000000000..e9013a6969 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql @@ -0,0 +1,25 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Store any accounts that have been requested to be deactivated. + * We part the account from all the rooms its in when its + * deactivated. This can take some time and synapse may be restarted + * before it completes, so store the user IDs here until the process + * is complete. + */ +CREATE TABLE users_pending_deactivation ( + user_id TEXT NOT NULL +); diff --git a/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py new file mode 100644 index 0000000000..49f5f2c003 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py @@ -0,0 +1,63 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +FIX_INDEXES = """ +-- rebuild indexes as uniques +DROP INDEX groups_invites_g_idx; +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +DROP INDEX groups_users_g_idx; +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); + +-- rename other indexes to actually match their table names.. 
+DROP INDEX groups_users_u_idx; +CREATE INDEX group_users_u_idx ON group_users(user_id); +DROP INDEX groups_invites_u_idx; +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +DROP INDEX groups_rooms_g_idx; +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +DROP INDEX groups_rooms_r_idx; +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + + # remove duplicates from group_users & group_invites tables + cur.execute( + """ + DELETE FROM group_users WHERE %s NOT IN ( + SELECT min(%s) FROM group_users GROUP BY group_id, user_id + ); + """ + % (rowid, rowid) + ) + cur.execute( + """ + DELETE FROM group_invites WHERE %s NOT IN ( + SELECT min(%s) FROM group_invites GROUP BY group_id, user_id + ); + """ + % (rowid, rowid) + ) + + for statement in get_statements(FIX_INDEXES.splitlines()): + cur.execute(statement) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql new file mode 100644 index 0000000000..ce26eaf0c9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql @@ -0,0 +1,22 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This isn't a real ENUM because sqlite doesn't support it + * and we use a default of NULL for inserted rows and interpret + * NULL at the python store level as necessary so that existing + * rows are given the correct default policy. + */ +ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql new file mode 100644 index 0000000000..14dcf18d73 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql @@ -0,0 +1,20 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* record whether we have sent a server notice about consenting to the + * privacy policy. Specifically records the version of the policy we sent + * a message about. 
+ */ +ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT; diff --git a/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql new file mode 100644 index 0000000000..3dd478196f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql @@ -0,0 +1,21 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, + device_id TEXT, + timestamp BIGINT NOT NULL ); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql new file mode 100644 index 0000000000..3a4ed59b5b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_only_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql new file mode 100644 index 0000000000..c93ae47532 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
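The user_daily_visits delta above pairs the table with a composite (user_id, timestamp) index, which is what lets per-user, time-bounded visit queries avoid a full scan. One way to sanity-check that locally on SQLite is EXPLAIN QUERY PLAN; this is only an illustration, and the exact plan wording varies between SQLite versions.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE user_daily_visits ("
        " user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL)"
    )
    cur.execute("CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp)")
    cur.execute("CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp)")

    # A per-user, time-bounded query should be answered from the
    # (user_id, timestamp) index rather than a table scan.
    plan = cur.execute(
        "EXPLAIN QUERY PLAN "
        "SELECT timestamp FROM user_daily_visits "
        "WHERE user_id = ? AND timestamp > ?",
        ("@user:example.com", 0),
    ).fetchall()
    for row in plan:
        print(row)  # expect a SEARCH using user_daily_visits_uts_idx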
+ */ + + + +INSERT into background_updates (update_name, progress_json) + VALUES ('users_creation_ts', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/50/erasure_store.sql b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql new file mode 100644 index 0000000000..5d8641a9ab --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql @@ -0,0 +1,21 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of users who have requested that their details be erased +CREATE TABLE erased_users ( + user_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); diff --git a/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py new file mode 100644 index 0000000000..b1684a8441 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +We want to stop populating 'event.content', so we need to make it nullable. 
+ +If this has to be rolled back, then the following should populate the missing data: + +Postgres: + + UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej + WHERE ej.event_id = events.event_id AND + stream_ordering < ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering LIMIT 1 + ); + + UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej + WHERE ej.event_id = events.event_id AND + stream_ordering > ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering DESC LIMIT 1 + ); + +SQLite: + + UPDATE events SET content=( + SELECT json_extract(json,'$.content') FROM event_json ej + WHERE ej.event_id = events.event_id + ) + WHERE + stream_ordering < ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering LIMIT 1 + ) + OR stream_ordering > ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering DESC LIMIT 1 + ); + +""" + +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + pass + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute( + """ + ALTER TABLE events ALTER COLUMN content DROP NOT NULL; + """ + ) + return + + # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html + + cur.execute( + "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" + ) + (oldsql,) = cur.fetchone() + + sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") + if sql == oldsql: + raise Exception("Couldn't find null constraint to drop in %s" % oldsql) + + logger.info("Replacing definition of 'events' with: %s", sql) + + cur.execute("PRAGMA schema_version") + (oldver,) = cur.fetchone() + cur.execute("PRAGMA writable_schema=ON") + cur.execute( + "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", + (sql,), + ) + cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) + cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql new file mode 100644 index 0000000000..c0e66a697d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
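The make_event_content_nullable.py script above relies on SQLite's writable_schema escape hatch: rewrite the table's entry in sqlite_master, bump schema_version, and the NOT NULL constraint is gone without copying the table. A toy reproduction of the trick on an invented two-column table; the 'toy.db' filename is arbitrary, the connection is reopened so the edited schema is re-read, and SQLite builds with defensive mode enabled may refuse the sqlite_master update.

    import os
    import sqlite3

    DB = "toy.db"  # throwaway database file, name chosen arbitrarily
    if os.path.exists(DB):
        os.remove(DB)

    conn = sqlite3.connect(DB)
    cur = conn.cursor()
    cur.execute("CREATE TABLE events (event_id TEXT NOT NULL, content TEXT NOT NULL)")

    # Fetch the current table definition and drop the NOT NULL we no longer want.
    (oldsql,) = cur.execute(
        "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'"
    ).fetchone()
    newsql = oldsql.replace("content TEXT NOT NULL", "content TEXT")

    (oldver,) = cur.execute("PRAGMA schema_version").fetchone()
    cur.execute("PRAGMA writable_schema=ON")
    cur.execute(
        "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
        (newsql,),
    )
    cur.execute("PRAGMA schema_version=%i" % (oldver + 1,))
    cur.execute("PRAGMA writable_schema=OFF")
    conn.commit()
    conn.close()

    # Re-open so the rewritten schema is picked up; NULL content is now accepted.
    conn = sqlite3.connect(DB)
    conn.execute("INSERT INTO events (event_id, content) VALUES ('$ev1', NULL)")
    conn.commit()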
+ */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version TEXT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE e2e_room_keys_versions ( + user_id TEXT NOT NULL, + version TEXT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql new file mode 100644 index 0000000000..c9d537d5a3 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql @@ -0,0 +1,27 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of monthly active users, for use where blocking based on mau limits +CREATE TABLE monthly_active_users ( + user_id TEXT NOT NULL, + -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting + -- on updates, Granularity of updates governed by + -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY + -- Measured in ms since epoch. + timestamp BIGINT NOT NULL +); + +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql new file mode 100644 index 0000000000..91e03d13e1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is needed to efficiently check for unreferenced state groups during +-- purge. 
Added events_to_state_group(state_group) index +INSERT into background_updates (update_name, progress_json) + VALUES ('event_to_state_groups_sg_index', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql new file mode 100644 index 0000000000..bfa49e6f92 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql @@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will create a unique index on +-- device_lists_remote_cache +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_remote_cache_unique_idx', '{}'); + +-- and one on device_lists_remote_extremeties +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'device_lists_remote_extremeties_unique_idx', '{}', + + -- doesn't really depend on this, but we need to make sure both happen + -- before we drop the old indexes. + 'device_lists_remote_cache_unique_idx' + ); + +-- once they complete, we can drop the old indexes. +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'drop_device_list_streams_non_unique_indexes', '{}', + 'device_lists_remote_extremeties_unique_idx' + ); diff --git a/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql new file mode 100644 index 0000000000..db687cccae --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql @@ -0,0 +1,53 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
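The device_lists delta above chains its three background updates with depends_on so that the old non-unique indexes are only dropped once both replacement unique indexes exist. A toy scheduler showing how that ordering falls out of the depends_on column; the real sequencing is done by Synapse's background updater, and the print/DELETE below merely stand in for running and completing an update.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE background_updates ("
        " update_name TEXT PRIMARY KEY, progress_json TEXT NOT NULL, depends_on TEXT)"
    )
    cur.executemany(
        "INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES (?, ?, ?)",
        [
            ("device_lists_remote_cache_unique_idx", "{}", None),
            ("device_lists_remote_extremeties_unique_idx", "{}", "device_lists_remote_cache_unique_idx"),
            ("drop_device_list_streams_non_unique_indexes", "{}", "device_lists_remote_extremeties_unique_idx"),
        ],
    )

    done = set()
    while True:
        # Only pick an update whose dependency (if any) has already completed.
        rows = cur.execute("SELECT update_name, depends_on FROM background_updates").fetchall()
        runnable = [name for name, dep in rows if dep is None or dep in done]
        if not runnable:
            break
        name = runnable[0]
        print("running", name)  # the real handler would build or drop the indexes here
        done.add(name)
        cur.execute("DELETE FROM background_updates WHERE update_name = ?", (name,))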
+ */ + +/* Change version column to an integer so we can do MAX() sensibly + */ +CREATE TABLE e2e_room_keys_versions_new ( + user_id TEXT NOT NULL, + version BIGINT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +INSERT INTO e2e_room_keys_versions_new + SELECT user_id, CAST(version as BIGINT), algorithm, auth_data, deleted FROM e2e_room_keys_versions; + +DROP TABLE e2e_room_keys_versions; +ALTER TABLE e2e_room_keys_versions_new RENAME TO e2e_room_keys_versions; + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); + +/* Change e2e_rooms_keys to match + */ +CREATE TABLE e2e_room_keys_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version BIGINT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +INSERT INTO e2e_room_keys_new + SELECT user_id, room_id, session_id, CAST(version as BIGINT), first_message_index, forwarded_count, is_verified, session_data FROM e2e_room_keys; + +DROP TABLE e2e_room_keys; +ALTER TABLE e2e_room_keys_new RENAME TO e2e_room_keys; + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); diff --git a/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql new file mode 100644 index 0000000000..88ec2f83e5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The type of the user: NULL for a regular user, or one of the constants in + * synapse.api.constants.UserTypes + */ +ALTER TABLE users ADD COLUMN user_type TEXT DEFAULT NULL; diff --git a/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql new file mode 100644 index 0000000000..e372f5a44a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
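The 52/e2e_room_keys delta above rebuilds both tables just to turn version from TEXT into BIGINT because, as its comment says, MAX() over a TEXT column compares lexicographically and gives the wrong answer once versions reach double digits. A small demonstration of the difference; the table names are invented.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE versions_text (version TEXT NOT NULL)")
    cur.execute("CREATE TABLE versions_int (version BIGINT NOT NULL)")
    cur.executemany("INSERT INTO versions_text VALUES (?)", [("9",), ("10",)])
    cur.executemany("INSERT INTO versions_int VALUES (?)", [(9,), (10,)])

    # Lexicographic comparison picks '9' over '10'; numeric comparison picks 10.
    print(cur.execute("SELECT MAX(version) FROM versions_text").fetchone())  # ('9',)
    print(cur.execute("SELECT MAX(version) FROM versions_int").fetchone())   # (10,)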
+ */ + +DROP TABLE IF EXISTS sent_transactions; diff --git a/synapse/storage/databases/main/schema/delta/53/event_format_version.sql b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql new file mode 100644 index 0000000000..1d977c2834 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE event_json ADD COLUMN format_version INTEGER; diff --git a/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql new file mode 100644 index 0000000000..ffcc896b58 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_user_directory_createtables', '{}'); + +-- Run through each room and update the user directory according to who is in it +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables'); + +-- Insert all users, if search_all_users is on +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users'); diff --git a/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql new file mode 100644 index 0000000000..b812c5794f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql @@ -0,0 +1,30 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- analyze user_ips, to help ensure the correct indices are used +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_analyze', '{}'); + +-- delete duplicates +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_remove_dupes', '{}', 'user_ips_analyze'); + +-- add a new unique index to user_ips table +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes'); + +-- drop the old original index +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index'); diff --git a/synapse/storage/databases/main/schema/delta/53/user_share.sql b/synapse/storage/databases/main/schema/delta/53/user_share.sql new file mode 100644 index 0000000000..5831b1a6f8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/user_share.sql @@ -0,0 +1,44 @@ +/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Old disused version of the tables below. +DROP TABLE IF EXISTS users_who_share_rooms; + +-- Tables keeping track of what users share rooms. This is a map of local users +-- to local or remote users, per room. Remote users cannot be in the user_id +-- column, only the other_user_id column. There are two tables, one for public +-- rooms and those for private rooms. 
+CREATE TABLE IF NOT EXISTS users_who_share_public_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS users_who_share_private_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id); +CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id); + +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); + +-- Make sure that we populate the tables initially by resetting the stream ID +UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql new file mode 100644 index 0000000000..80c2c573b6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks which identity server a user bound their threepid via. +CREATE TABLE user_threepid_id_server ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + id_server TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( + user_id, medium, address, id_server +); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_threepids_grandfather', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql new file mode 100644 index 0000000000..f7827ca6d2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql @@ -0,0 +1,28 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We don't need the old version of this table. 
+DROP TABLE IF EXISTS users_in_public_rooms; + +-- Old version of users_in_public_rooms +DROP TABLE IF EXISTS users_who_share_public_rooms; + +-- Track what users are in public rooms. +CREATE TABLE IF NOT EXISTS users_in_public_rooms ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); diff --git a/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql new file mode 100644 index 0000000000..0adb2ad55e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We previously changed the schema for this table without renaming the file, which means +-- that some databases might still be using the old schema. This ensures Synapse uses the +-- right schema for the table. +DROP TABLE IF EXISTS account_validity; + +-- Track what users are in public rooms. +CREATE TABLE IF NOT EXISTS account_validity ( + user_id TEXT PRIMARY KEY, + expiration_ts_ms BIGINT NOT NULL, + email_sent BOOLEAN NOT NULL, + renewal_token TEXT +); + +CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) +CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) diff --git a/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql new file mode 100644 index 0000000000..c01aa9d2d9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* When we can use this key until, before we have to refresh it. 
*/ +ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; + +UPDATE server_signature_keys SET ts_valid_until_ms = ( + SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE + skj.server_name = server_signature_keys.server_name AND + skj.key_id = server_signature_keys.key_id +); diff --git a/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql new file mode 100644 index 0000000000..b062ec840c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Start a background job to cleanup extremities that were incorrectly added +-- by bug #5269. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('delete_soft_failed_extremities', '{}'); + +DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. +CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; +CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); diff --git a/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql new file mode 100644 index 0000000000..dbbe682697 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
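The add_validity_to_server_keys delta above backfills the new ts_valid_until_ms column with a correlated subquery that pulls the newest expiry recorded in server_keys_json for the same server and key. The same shape on cut-down copies of the two tables (the real tables have more columns, and the sample rows are invented):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # Cut-down versions of the two tables involved.
    cur.execute("CREATE TABLE server_signature_keys (server_name TEXT, key_id TEXT, ts_valid_until_ms BIGINT)")
    cur.execute("CREATE TABLE server_keys_json (server_name TEXT, key_id TEXT, ts_valid_until_ms BIGINT)")

    cur.execute("INSERT INTO server_signature_keys VALUES ('example.com', 'ed25519:a', NULL)")
    cur.executemany(
        "INSERT INTO server_keys_json VALUES (?, ?, ?)",
        [("example.com", "ed25519:a", 1000), ("example.com", "ed25519:a", 2000)],
    )

    # Correlated subquery: each row gets the newest expiry recorded for its key.
    cur.execute(
        "UPDATE server_signature_keys SET ts_valid_until_ms = ("
        " SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj"
        " WHERE skj.server_name = server_signature_keys.server_name"
        "   AND skj.key_id = server_signature_keys.key_id)"
    )
    print(cur.execute("SELECT * FROM server_signature_keys").fetchone())
    # ('example.com', 'ed25519:a', 2000)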
+ */ + +-- we need to do this first due to foreign constraints +DROP TABLE IF EXISTS application_services_regex; + +DROP TABLE IF EXISTS application_services; +DROP TABLE IF EXISTS transaction_id_to_pdu; +DROP TABLE IF EXISTS stats_reporting; +DROP TABLE IF EXISTS current_state_resets; +DROP TABLE IF EXISTS event_content_hashes; +DROP TABLE IF EXISTS event_destinations; +DROP TABLE IF EXISTS event_edge_hashes; +DROP TABLE IF EXISTS event_signatures; +DROP TABLE IF EXISTS feedback; +DROP TABLE IF EXISTS room_hosts; +DROP TABLE IF EXISTS server_tls_certificates; +DROP TABLE IF EXISTS state_forward_extremities; diff --git a/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql new file mode 100644 index 0000000000..e6ee70c623 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS presence_list; diff --git a/synapse/storage/databases/main/schema/delta/54/relations.sql b/synapse/storage/databases/main/schema/delta/54/relations.sql new file mode 100644 index 0000000000..134862b870 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/relations.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks related events, like reactions, replies, edits, etc. Note that things +-- in this table are not necessarily "valid", e.g. it may contain edits from +-- people who don't have power to edit other peoples events. +CREATE TABLE IF NOT EXISTS event_relations ( + event_id TEXT NOT NULL, + relates_to_id TEXT NOT NULL, + relation_type TEXT NOT NULL, + aggregation_key TEXT +); + +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/databases/main/schema/delta/54/stats.sql b/synapse/storage/databases/main/schema/delta/54/stats.sql new file mode 100644 index 0000000000..652e58308e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/stats.sql @@ -0,0 +1,80 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
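The 54/relations.sql delta above indexes event_relations on (relates_to_id, relation_type, aggregation_key), which is the access path needed to aggregate annotations (reactions) per target event. A toy aggregation over invented rows; the relation type strings follow the Matrix relations model, everything else here is made up.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE event_relations ("
        " event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL,"
        " relation_type TEXT NOT NULL, aggregation_key TEXT)"
    )
    cur.execute(
        "CREATE INDEX event_relations_relates "
        "ON event_relations(relates_to_id, relation_type, aggregation_key)"
    )

    # Three annotations (reactions) on the same target event, plus one edit.
    cur.executemany(
        "INSERT INTO event_relations VALUES (?, ?, ?, ?)",
        [
            ("$r1", "$target", "m.annotation", "👍"),
            ("$r2", "$target", "m.annotation", "👍"),
            ("$r3", "$target", "m.annotation", "❤"),
            ("$e1", "$target", "m.replace", None),
        ],
    )

    # Count reactions for one event, grouped by emoji; this is the kind of
    # lookup the composite index above is there to serve.
    rows = cur.execute(
        "SELECT aggregation_key, COUNT(*) FROM event_relations"
        " WHERE relates_to_id = ? AND relation_type = ?"
        " GROUP BY aggregation_key",
        ("$target", "m.annotation"),
    ).fetchall()
    print(rows)  # counts per emoji, e.g. [('❤', 1), ('👍', 2)]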
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stats_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO stats_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_stats ( + user_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + public_rooms INT NOT NULL, + private_rooms INT NOT NULL +); + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); + +CREATE TABLE room_stats ( + room_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + state_events INT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); + +-- cache of current room state; useful for the publicRooms list +CREATE TABLE room_state ( + room_id TEXT NOT NULL, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + name TEXT, + topic TEXT, + avatar TEXT, + canonical_alias TEXT + -- get aliases straight from the right table +); + +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); + +CREATE TABLE room_stats_earliest_token ( + room_id TEXT NOT NULL, + token BIGINT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_createtables', '{}'); + +-- Run through each room and update stats +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/databases/main/schema/delta/54/stats2.sql b/synapse/storage/databases/main/schema/delta/54/stats2.sql new file mode 100644 index 0000000000..3b2d48447f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/54/stats2.sql @@ -0,0 +1,28 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This delta file gets run after `54/stats.sql` delta. + +-- We want to add some indices to the temporary stats table, so we re-insert +-- 'populate_stats_createtables' if we are still processing the rooms update. 
+INSERT INTO background_updates (update_name, progress_json) + SELECT 'populate_stats_createtables', '{}' + WHERE + 'populate_stats_process_rooms' IN ( + SELECT update_name FROM background_updates + ) + AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists + SELECT update_name FROM background_updates + ); diff --git a/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql new file mode 100644 index 0000000000..4590604bfd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- when this access token can be used until, in ms since the epoch. NULL means the token +-- never expires. +ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql new file mode 100644 index 0000000000..a8eced2e0a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql @@ -0,0 +1,31 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS threepid_validation_session ( + session_id TEXT PRIMARY KEY, + medium TEXT NOT NULL, + address TEXT NOT NULL, + client_secret TEXT NOT NULL, + last_send_attempt BIGINT NOT NULL, + validated_at BIGINT +); + +CREATE TABLE IF NOT EXISTS threepid_validation_token ( + token TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + next_link TEXT, + expires BIGINT NOT NULL +); + +CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token(session_id); diff --git a/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql new file mode 100644 index 0000000000..dabdde489b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql @@ -0,0 +1,19 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
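The stats2 delta above guards its INSERT with a SELECT so that 'populate_stats_createtables' is only re-queued while 'populate_stats_process_rooms' is still pending, and is never queued twice. The same guarded INSERT ... SELECT ... WHERE pattern, run twice against a toy background_updates table to show the second attempt is a no-op:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE background_updates (update_name TEXT PRIMARY KEY, progress_json TEXT NOT NULL)"
    )
    # Pretend the rooms pass is still pending.
    cur.execute("INSERT INTO background_updates VALUES ('populate_stats_process_rooms', '{}')")

    guarded_insert = """
        INSERT INTO background_updates (update_name, progress_json)
        SELECT 'populate_stats_createtables', '{}'
        WHERE 'populate_stats_process_rooms' IN (SELECT update_name FROM background_updates)
          AND 'populate_stats_createtables' NOT IN (SELECT update_name FROM background_updates)
    """
    cur.execute(guarded_insert)  # inserts the row, because both guards pass
    cur.execute(guarded_insert)  # no-op the second time: the row already exists
    print(cur.execute("SELECT update_name FROM background_updates ORDER BY update_name").fetchall())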
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD deactivated SMALLINT DEFAULT 0 NOT NULL; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('users_set_deactivated_flag', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql new file mode 100644 index 0000000000..41807eb1e7 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Opentracing context data for inclusion in the device_list_update EDUs, as a + * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination). + */ +ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT; diff --git a/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql new file mode 100644 index 0000000000..473018676f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We add membership to current state so that we don't need to join against +-- room_memberships, which can be surprisingly costly (we do such queries +-- very frequently). +-- This will be null for non-membership events and the content.membership key +-- for membership events. (Will also be null for membership events until the +-- background update job has finished). 
+ALTER TABLE current_state_events ADD membership TEXT; diff --git a/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql new file mode 100644 index 0000000000..3133d42d4a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We add membership to current state so that we don't need to join against +-- room_memberships, which can be surprisingly costly (we do such queries +-- very frequently). +-- This will be null for non-membership events and the content.membership key +-- for membership events. (Will also be null for membership events until the +-- background update job has finished). + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_events_membership', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql new file mode 100644 index 0000000000..f00889290b --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Record the timestamp when a given server started failing + */ +ALTER TABLE destinations ADD failure_ts BIGINT; + +/* as a rough approximation, we assume that the server started failing at + * retry_interval before the last retry + */ +UPDATE destinations SET failure_ts = retry_last_ts - retry_interval + WHERE retry_last_ts > 0; diff --git a/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres new file mode 100644 index 0000000000..b9bbb18a91 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We want to store large retry intervals so we upgrade the column from INT +-- to BIGINT. We don't need to do this on SQLite. +ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql new file mode 100644 index 0000000000..c2f557fde9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This line already existed in deltas/35/device_stream_id but was not included in the +-- 54 full schema SQL. 
Add some SQL here to insert the missing row if it does not exist +INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( + SELECT * from device_max_stream_id +); \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql new file mode 100644 index 0000000000..dfa902d0ba --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track last seen information for a device in the devices table, rather +-- than relying on it being in the user_ips table (which we want to be able +-- to purge old entries from) +ALTER TABLE devices ADD COLUMN last_seen BIGINT; +ALTER TABLE devices ADD COLUMN ip TEXT; +ALTER TABLE devices ADD COLUMN user_agent TEXT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('devices_last_seen', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql new file mode 100644 index 0000000000..9f09922c67 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- these tables are never used. +DROP TABLE IF EXISTS room_names; +DROP TABLE IF EXISTS topics; +DROP TABLE IF EXISTS history_visibility; +DROP TABLE IF EXISTS guest_access; diff --git a/synapse/storage/databases/main/schema/delta/56/event_expiry.sql b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+    event_id TEXT PRIMARY KEY,
+    expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
new file mode 100644
index 0000000000..5e29c1da19
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topological_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+    event_id TEXT,
+    label TEXT,
+    room_id TEXT NOT NULL,
+    topological_ordering BIGINT NOT NULL,
+    PRIMARY KEY(event_id, label)
+);
+
+
+-- This index enables a pagination request that filters on a particular label to hit the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label, if the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql
new file mode 100644
index 0000000000..5f5e0499ae
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('event_store_labels', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql
new file mode 100644
index 0000000000..014cb3b538
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- version is supposed to be part of the room keys index +CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id); +DROP INDEX IF EXISTS e2e_room_keys_idx; diff --git a/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql new file mode 100644 index 0000000000..67f8b20297 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- device list needs to know which ones are "real" devices, and which ones are +-- just used to avoid collisions +ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE; diff --git a/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite new file mode 100644 index 0000000000..e8b1fd35d8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite @@ -0,0 +1,42 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Change the hidden column from a default value of FALSE to a default value of + * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the + * string 'FALSE', which is truthy. + * + * Since sqlite doesn't allow us to just change the default value, we have to + * recreate the table, copy the data, fix the rows that have incorrect data, and + * replace the old table with the new table. 
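+ *
+ * (Illustrative note, not part of the original delta: after the rebuild below, the only
+ * values left in the column should be 0 and 1, so a sanity check along the lines of
+ *     SELECT COUNT(*) FROM devices WHERE hidden NOT IN (0, 1);
+ * would be expected to return 0 on a repaired database.)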
+ */ + +CREATE TABLE IF NOT EXISTS devices2 ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + last_seen BIGINT, + ip TEXT, + user_agent TEXT, + hidden BOOLEAN DEFAULT 0, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); + +INSERT INTO devices2 SELECT * FROM devices; + +UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE'; + +DROP TABLE devices; + +ALTER TABLE devices2 RENAME TO devices; diff --git a/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql new file mode 100644 index 0000000000..4f24c1405d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 Werner Sembach + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. +DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE 
group_id = groups.group_id)); +DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); diff --git a/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql new file mode 100644 index 0000000000..7be31ffebb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql new file mode 100644 index 0000000000..ea95db0ed7 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; diff --git a/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql new file mode 100644 index 0000000000..49ce35d794 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +ALTER TABLE redactions ADD COLUMN received_ts BIGINT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_received_ts', '{}'); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_have_censored_ts_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres new file mode 100644 index 0000000000..67471f3ef5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- There was a bug where we may have updated censored redactions as bytes, +-- which can (somehow) cause json to be inserted hex encoded. These updates go +-- and undoes any such hex encoded JSON. + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_fix_redactions_bytes_create_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index'); diff --git a/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql new file mode 100644 index 0000000000..b7550f6f4e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS redactions_have_censored; diff --git a/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql new file mode 100644 index 0000000000..aeb17813d3 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Now that #6232 is a thing, we can remove old rooms from the directory. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('remove_tombstoned_rooms_from_directory', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql new file mode 100644 index 0000000000..92ab1f5e65 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Adds an index on room_memberships for fetching all forgotten rooms for a user +INSERT INTO background_updates (update_name, progress_json) VALUES + ('room_membership_forgotten_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/room_retention.sql b/synapse/storage/databases/main/schema/delta/56/room_retention.sql new file mode 100644 index 0000000000..ee6cdf7a14 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/room_retention.sql @@ -0,0 +1,33 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks the retention policy of a room. +-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in +-- the room's retention policy state event. +-- If a room doesn't have a retention policy state event in its state, both max_lifetime +-- and min_lifetime are NULL. +CREATE TABLE IF NOT EXISTS room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); + +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('insert_room_retention', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/56/signing_keys.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql new file mode 100644 index 0000000000..5c5fffcafb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql @@ -0,0 +1,56 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +-- cross-signing keys +CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( + user_id TEXT NOT NULL, + -- the type of cross-signing key (master, user_signing, or self_signing) + keytype TEXT NOT NULL, + -- the full key information, as a json-encoded dict + keydata TEXT NOT NULL, + -- for keeping the keys in order, so that we can fetch the latest one + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id); + +-- cross-signing signatures +CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( + -- user who did the signing + user_id TEXT NOT NULL, + -- key used to sign + key_id TEXT NOT NULL, + -- user who was signed + target_user_id TEXT NOT NULL, + -- device/key that was signed + target_device_id TEXT NOT NULL, + -- the actual signature + signature TEXT NOT NULL +); + +-- replaced by the index created in signing_keys_nonunique_signatures.sql +-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); + +-- stream of user signature updates +CREATE TABLE IF NOT EXISTS user_signature_stream ( + -- uses the same stream ID as device list stream + stream_id BIGINT NOT NULL, + -- user who did the signing + from_user_id TEXT NOT NULL, + -- list of users who were signed, as a JSON array + user_ids TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); diff --git a/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql new file mode 100644 index 0000000000..0aa90ebf0c --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The cross-signing signatures index should not be a unique index, because a + * user may upload multiple signatures for the same target user. The previous + * index was unique, so delete it if it's there and create a new non-unique + * index. */ + +DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT +EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); diff --git a/synapse/storage/databases/main/schema/delta/56/stats_separated.sql b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql new file mode 100644 index 0000000000..bbdde121e8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql @@ -0,0 +1,156 @@ +/* Copyright 2018 New Vector Ltd + * Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +----- First clean up from previous versions of room stats. + +-- First remove old stats stuff +DROP TABLE IF EXISTS room_stats; +DROP TABLE IF EXISTS room_state; +DROP TABLE IF EXISTS room_stats_state; +DROP TABLE IF EXISTS user_stats; +DROP TABLE IF EXISTS room_stats_earliest_tokens; +DROP TABLE IF EXISTS _temp_populate_stats_position; +DROP TABLE IF EXISTS _temp_populate_stats_rooms; +DROP TABLE IF EXISTS stats_stream_pos; + +-- Unschedule old background updates if they're still scheduled +DELETE FROM background_updates WHERE update_name IN ( + 'populate_stats_createtables', + 'populate_stats_process_rooms', + 'populate_stats_process_users', + 'populate_stats_cleanup' +); + +-- this relies on current_state_events.membership having been populated, so add +-- a dependency on current_state_events_membership. +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); + +-- this also relies on current_state_events.membership having been populated, but +-- we get that as a side-effect of depending on populate_stats_process_rooms. +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); + +----- Create tables for our version of room stats. + +-- single-row table to track position of incremental updates +DROP TABLE IF EXISTS stats_incremental_position; +CREATE TABLE stats_incremental_position ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +-- insert a null row and make sure it is the only one. +INSERT INTO stats_incremental_position ( + stream_id +) SELECT COALESCE(MAX(stream_ordering), 0) from events; + +-- represents PRESENT room statistics for a room +-- only holds absolute fields +DROP TABLE IF EXISTS room_stats_current; +CREATE TABLE room_stats_current ( + room_id TEXT NOT NULL PRIMARY KEY, + + -- These are absolute counts + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + + local_users_in_room INT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + + +-- represents HISTORICAL room statistics for a room +DROP TABLE IF EXISTS room_stats_historical; +CREATE TABLE room_stats_historical ( + room_id TEXT NOT NULL, + -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms). + -- Note that end_ts is quantised. 
+ end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + -- These stats are absolute counts + current_state_events BIGINT NOT NULL, + joined_members BIGINT NOT NULL, + invited_members BIGINT NOT NULL, + left_members BIGINT NOT NULL, + banned_members BIGINT NOT NULL, + local_users_in_room BIGINT NOT NULL, + + -- These stats are per time slice + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (room_id, end_ts) +); + +-- We use this index to speed up deletion of ancient room stats. +CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts); + +-- represents PRESENT statistics for a user +-- only holds absolute fields +DROP TABLE IF EXISTS user_stats_current; +CREATE TABLE user_stats_current ( + user_id TEXT NOT NULL PRIMARY KEY, + + joined_rooms BIGINT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + +-- represents HISTORICAL statistics for a user +DROP TABLE IF EXISTS user_stats_historical; +CREATE TABLE user_stats_historical ( + user_id TEXT NOT NULL, + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + joined_rooms BIGINT NOT NULL, + + invites_sent BIGINT NOT NULL, + rooms_created BIGINT NOT NULL, + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (user_id, end_ts) +); + +-- We use this index to speed up deletion of ancient user stats. +CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts); + + +CREATE TABLE room_stats_state ( + room_id TEXT NOT NULL, + name TEXT, + canonical_alias TEXT, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + avatar TEXT, + guest_access TEXT, + is_federatable BOOLEAN, + topic TEXT +); + +CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py new file mode 100644 index 0000000000..1de8b54961 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py @@ -0,0 +1,52 @@ +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +""" +This migration updates the user_filters table as follows: + + - drops any (user_id, filter_id) duplicates + - makes the columns NON-NULLable + - turns the index into a UNIQUE index +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + select_clause = """ + SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json + FROM user_filters + """ + else: + select_clause = """ + SELECT * FROM user_filters GROUP BY user_id, filter_id + """ + sql = """ + DROP TABLE IF EXISTS user_filters_migration; + DROP INDEX IF EXISTS user_filters_unique; + CREATE TABLE user_filters_migration ( + user_id TEXT NOT NULL, + filter_id BIGINT NOT NULL, + filter_json BYTEA NOT NULL + ); + INSERT INTO user_filters_migration (user_id, filter_id, filter_json) + %s; + CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration + (user_id, filter_id); + DROP TABLE user_filters; + ALTER TABLE user_filters_migration RENAME TO user_filters; + """ % ( + select_clause, + ) + + if isinstance(database_engine, PostgresEngine): + cur.execute(sql) + else: + cur.executescript(sql) diff --git 
a/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql new file mode 100644 index 0000000000..91390c4527 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * a table which records mappings from external auth providers to mxids + */ +CREATE TABLE IF NOT EXISTS user_external_ids ( + auth_provider TEXT NOT NULL, + external_id TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (auth_provider, external_id) +); diff --git a/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql new file mode 100644 index 0000000000..149f8be8b6 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this was apparently forgotten when the table was created back in delta 53. +CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id); diff --git a/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql new file mode 100644 index 0000000000..aec06c8261 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add background update to go and delete current state events for rooms the +-- server is no longer in. +-- +-- this relies on the 'membership' column of current_state_events, so make sure +-- that's populated first! 
+INSERT into background_updates (update_name, progress_json, depends_on)
+    VALUES ('delete_old_current_state_events', '{}', 'current_state_events_membership');
diff --git a/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql
new file mode 100644
index 0000000000..c3b6de2099
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Records whether the server thinks that a remote user's cached device lists
+-- may be out of date (e.g. if we have received a to-device message from a
+-- device we don't know about).
+CREATE TABLE IF NOT EXISTS device_lists_remote_resync (
+    user_id TEXT NOT NULL,
+    added_ts BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id);
+CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts);
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..63b5acdcf7
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    # We need to do the insert in the `run_upgrade` section as we don't have
+    # access to `config` in `run_create`.
+
+    # This upgrade may take a bit of time for large servers (e.g. one minute for
+    # matrix.org) but means we avoid a lot of bookkeeping required to do it as
+    # a background update.
+
+    # We check if `current_state_events.membership` is up to date by
+    # checking if the relevant background update has finished. If it has
+    # finished we can avoid doing a join against `room_memberships`, which
+    # speeds things up.
+    cur.execute(
+        """SELECT 1 FROM background_updates
+            WHERE update_name = 'current_state_events_membership'
+        """
+    )
+    current_state_membership_up_to_date = not bool(cur.fetchone())
+
+    # Cheekily drop and recreate indices, as that is faster.
+    cur.execute("DROP INDEX local_current_membership_idx")
+    cur.execute("DROP INDEX local_current_membership_room_idx")
+
+    if current_state_membership_up_to_date:
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, c.membership
+                FROM current_state_events AS c
+                WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ?
+        """
+    else:
+        # We can't rely on the membership column, so we need to join against
+        # `room_memberships`.
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, r.membership
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS r USING (event_id)
+                WHERE type = 'm.room.member' AND state_key LIKE ?
+        """
+    sql = database_engine.convert_param_style(sql)
+    cur.execute(sql, ("%:" + config.server_name,))
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    cur.execute(
+        """
+        CREATE TABLE local_current_membership (
+            room_id TEXT NOT NULL,
+            user_id TEXT NOT NULL,
+            event_id TEXT NOT NULL,
+            membership TEXT NOT NULL
+        )"""
+    )
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
diff --git a/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql
new file mode 100644
index 0000000000..133d80af35
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- we no longer keep sent outbound device pokes in the db; clear them out
+-- so that we don't have to worry about them.
+--
+-- This is a sequence scan, but it doesn't take too long.
+ +DELETE FROM device_lists_outbound_pokes WHERE sent; diff --git a/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql new file mode 100644 index 0000000000..352a66f5b0 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We want to start storing the room version independently of +-- `current_state_events` so that we can delete stale entries from it without +-- losing the information. +ALTER TABLE rooms ADD COLUMN room_version TEXT; + + +INSERT into background_updates (update_name, progress_json) + VALUES ('add_rooms_room_version_column', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres new file mode 100644 index 0000000000..c601cff6de --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres @@ -0,0 +1,35 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- when we first added the room_version column, it was populated via a background +-- update. We now need it to be populated before synapse starts, so we populate +-- any remaining rows with a NULL room version now. For servers which have completed +-- the background update, this will be pretty quick. + +-- the following query will set room_version to NULL if no create event is found for +-- the room in current_state_events, and will set it to '1' if a create event with no +-- room_version is found. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json::json->'content'->>'room_version','1') + FROM current_state_events cse INNER JOIN event_json ej USING (event_id) + WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key='' +) WHERE rooms.room_version IS NULL; + +-- we still allow the background update to complete: it has the useful side-effect of +-- populating `rooms` with any missing rooms (based on the current_state_events table). + +-- see also rooms_version_column_2.sql.sqlite which has a copy of the above query, using +-- sqlite syntax for the json extraction. 
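+
+-- (Illustrative, not part of the original delta: event_json.json holds the full create
+-- event, shaped roughly like
+--     {"type": "m.room.create", "state_key": "", "content": {"room_version": "5", ...}, ...}
+-- so json::json->'content'->>'room_version' pulls out the version string, and the
+-- COALESCE(..., '1') falls back to '1' because version-1 rooms predate the
+-- content.room_version field.)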
diff --git a/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite new file mode 100644 index 0000000000..335c6f2074 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- see rooms_version_column_2.sql.postgres for details of what's going on here. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') + FROM current_state_events cse INNER JOIN event_json ej USING (event_id) + WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key='' +) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres new file mode 100644 index 0000000000..92aaadde0d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres @@ -0,0 +1,39 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- When we first added the room_version column to the rooms table, it was populated from +-- the current_state_events table. However, there was an issue causing a background +-- update to clean up the current_state_events table for rooms where the server is no +-- longer participating, before that column could be populated. Therefore, some rooms had +-- a NULL room_version. + +-- The rooms_version_column_2.sql.* delta files were introduced to make the populating +-- synchronous instead of running it in a background update, which fixed this issue. +-- However, all of the instances of Synapse installed or updated in the meantime got +-- their rooms table corrupted with NULL room_versions. + +-- This query fishes out the room versions from the create event using the state_events +-- table instead of the current_state_events one, as the former still have all of the +-- create events. 
+ +UPDATE rooms SET room_version=( + SELECT COALESCE(json::json->'content'->>'room_version','1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; + +-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using +-- sqlite syntax for the json extraction. diff --git a/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite new file mode 100644 index 0000000000..e19dab97cb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite @@ -0,0 +1,23 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- see rooms_version_column_3.sql.postgres for details of what's going on here. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql new file mode 100644 index 0000000000..fdc39e9ba5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* for some reason, we have accumulated duplicate entries in + * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + * efficient. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (5800, 'remove_dup_outbound_pokes', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql new file mode 100644 index 0000000000..dcb593fc2d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql @@ -0,0 +1,36 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS ui_auth_sessions( + session_id TEXT NOT NULL, -- The session ID passed to the client. + creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). + serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. + clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. + uri TEXT NOT NULL, -- The URI the UI authentication session is using. + method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. + -- The clientdict, uri, and method make up an tuple that must be immutable + -- throughout the lifetime of the UI Auth session. + description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. + UNIQUE (session_id) +); + +CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials( + session_id TEXT NOT NULL, -- The corresponding UI Auth session. + stage_type TEXT NOT NULL, -- The stage type. + result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. + UNIQUE (session_id, stage_type), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres new file mode 100644 index 0000000000..aa46eb0e10 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We keep the old table here to enable us to roll back. It doesn't matter +-- that we have dropped all the data here. +TRUNCATE cache_invalidation_stream; + +CREATE TABLE cache_invalidation_stream_by_instance ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + cache_func TEXT NOT NULL, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id); + +CREATE SEQUENCE cache_invalidation_stream_seq; diff --git a/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py new file mode 100644 index 0000000000..d353f2bcb3 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py @@ -0,0 +1,80 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This migration rebuilds the device_lists_outbound_last_success table without duplicate +entries, and with a UNIQUE index. +""" + +import logging +from io import StringIO + +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream +from synapse.storage.types import Cursor + +logger = logging.getLogger(__name__) + + +def run_upgrade(*args, **kwargs): + pass + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # some instances might already have this index, in which case we can skip this + if isinstance(database_engine, PostgresEngine): + cur.execute( + """ + SELECT 1 FROM pg_class WHERE relkind = 'i' + AND relname = 'device_lists_outbound_last_success_unique_idx' + """ + ) + + if cur.rowcount: + logger.info( + "Unique index exists on device_lists_outbound_last_success: " + "skipping rebuild" + ) + return + + logger.info("Rebuilding device_lists_outbound_last_success with unique index") + execute_statements_from_stream(cur, StringIO(_rebuild_commands)) + + +# there might be duplicates, so the easiest way to achieve this is to create a new +# table with the right data, and renaming it into place + +_rebuild_commands = """ +DROP TABLE IF EXISTS device_lists_outbound_last_success_new; + +CREATE TABLE device_lists_outbound_last_success_new ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +-- this took about 30 seconds on matrix.org's 16 million rows. +INSERT INTO device_lists_outbound_last_success_new + SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success + GROUP BY destination, user_id; + +-- and this another 30 seconds. +CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx + ON device_lists_outbound_last_success_new (destination, user_id); + +DROP TABLE device_lists_outbound_last_success; + +ALTER TABLE device_lists_outbound_last_success_new + RENAME TO device_lists_outbound_last_success; +""" diff --git a/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres new file mode 100644 index 0000000000..597f2ffd3d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The local_media_repository should have files which do not get quarantined, +-- e.g. files from sticker packs. +ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite new file mode 100644 index 0000000000..69db89ac0e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The local_media_repository should have files which do not get quarantined, +-- e.g. files from sticker packs. +ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0; diff --git a/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql new file mode 100644 index 0000000000..eb57203e46 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +The version of synapse 1.16.0 on pypi incorrectly contained a migration which +added a table called 'local_rejections_stream'. This table is not used, and +we drop it here for anyone who was affected. +*/ + +DROP TABLE IF EXISTS local_rejections_stream; diff --git a/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql new file mode 100644 index 0000000000..1cc2633aad --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We need to store the stream positions by instance in a sharded config world. +-- +-- We default to master as we want the column to be NOT NULL and we correctly +-- reset the instance name to match the config each time we start up. +ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master'; + +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); diff --git a/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py new file mode 100644 index 0000000000..4310ec12ce --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py @@ -0,0 +1,34 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adds a postgres SEQUENCE for generating guest user IDs. +""" + +from synapse.storage.databases.main.registration import ( + find_max_generated_user_id_localpart, +) +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + next_id = find_max_generated_user_id_localpart(cur) + 1 + cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql new file mode 100644 index 0000000000..cade5dcca8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql @@ -0,0 +1,32 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Recalculate the stats for all rooms after the fix to joined_members erroneously +-- incrementing on per-room profile changes. + +-- Note that the populate_stats_process_rooms background update is already set to +-- run if you're upgrading from Synapse <1.0.0. 
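-- (Illustrative aside from the editor, not part of this delta: whether the
-- replacement job registered at the end of this file has finished can be checked
-- with a query like the one below, since completed updates are removed from the
-- background_updates table once they complete.)

SELECT update_name, progress_json
  FROM background_updates
 WHERE update_name = 'populate_stats_process_rooms_2';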
+ +-- Additionally, if you've upgraded to v1.18.0 (which doesn't include this fix), +-- this bg job runs, and then update to v1.19.0, you'd end up with only half of +-- your rooms having room stats recalculated after this fix was in place. + +-- So we've switched the old `populate_stats_process_rooms` background job to a +-- no-op, and then kick off a bg job with a new name, but with the same +-- functionality as the old one. This effectively restarts the background job +-- from the beginning, without running it twice in a row, supporting both +-- upgrade usecases. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql new file mode 100644 index 0000000000..531b532c73 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Store a boolean value in the events table for whether the event should be counted in +-- the unread_count property of sync responses. +ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql new file mode 100644 index 0000000000..883fcd10b2 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql @@ -0,0 +1,37 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create tables called application_services and + * application_services_regex, but these are no longer used and are removed in + * delta 54. 
+ */ + + +CREATE TABLE IF NOT EXISTS application_services_state( + as_id TEXT PRIMARY KEY, + state VARCHAR(5), + last_txn INTEGER +); + +CREATE TABLE IF NOT EXISTS application_services_txns( + as_id TEXT NOT NULL, + txn_id INTEGER NOT NULL, + event_ids TEXT NOT NULL, + UNIQUE(as_id, txn_id) +); + +CREATE INDEX application_services_txns_id ON application_services_txns ( + as_id +); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql new file mode 100644 index 0000000000..10ce2aa7a0 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql @@ -0,0 +1,70 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create tables called event_destinations and + * state_forward_extremities, but these are no longer used and are removed in + * delta 54. + */ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT NOT NULL, + prev_event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular + -- event dag edge. + UNIQUE (event_id, prev_event_id, room_id, is_state) +); + +CREATE INDEX ev_edges_id ON event_edges(event_id); +CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT NOT NULL, + min_depth INTEGER NOT NULL, + UNIQUE (room_id) +); + +CREATE INDEX room_depth_room ON room_depth(room_id); + +CREATE TABLE IF NOT EXISTS event_auth( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, auth_id, room_id) +); + +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql new file mode 100644 index 0000000000..95826da431 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql @@ -0,0 +1,38 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* We used to create tables called event_content_hashes and event_edge_hashes, + * but these are no longer used and are removed in delta 54. + */ + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash bytea, + UNIQUE (event_id, algorithm) +); + +CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); + + +CREATE TABLE IF NOT EXISTS event_signatures ( + event_id TEXT, + signature_name TEXT, + key_id TEXT, + signature bytea, + UNIQUE (event_id, signature_name, key_id) +); + +CREATE INDEX event_signatures_id ON event_signatures(event_id); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/im.sql b/synapse/storage/databases/main/schema/full_schemas/16/im.sql new file mode 100644 index 0000000000..a1a2aa8e5b --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/im.sql @@ -0,0 +1,120 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create tables called room_hosts and feedback, + * but these are no longer used and are removed in delta 54. + */ + +CREATE TABLE IF NOT EXISTS events( + stream_ordering INTEGER PRIMARY KEY, + topological_ordering BIGINT NOT NULL, + event_id TEXT NOT NULL, + type TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- 'content' used to be created NULLable, but as of delta 50 we drop that constraint. + -- the hack we use to drop the constraint doesn't work for an in-memory sqlite + -- database, which breaks the sytests. Hence, we no longer make it nullable. 
+ content TEXT, + + unrecognized_keys TEXT, + processed BOOL NOT NULL, + outlier BOOL NOT NULL, + depth BIGINT DEFAULT 0 NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX events_stream_ordering ON events (stream_ordering); +CREATE INDEX events_topological_ordering ON events (topological_ordering); +CREATE INDEX events_order ON events (topological_ordering, stream_ordering); +CREATE INDEX events_room_id ON events (room_id); +CREATE INDEX events_order_room ON events ( + room_id, topological_ordering, stream_ordering +); + + +CREATE TABLE IF NOT EXISTS event_json( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + internal_metadata TEXT NOT NULL, + json TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX event_json_room_id ON event_json(room_id); + + +CREATE TABLE IF NOT EXISTS state_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + prev_state TEXT, + UNIQUE (event_id) +); + +CREATE INDEX state_events_room_id ON state_events (room_id); +CREATE INDEX state_events_type ON state_events (type); +CREATE INDEX state_events_state_key ON state_events (state_key); + + +CREATE TABLE IF NOT EXISTS current_state_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + UNIQUE (event_id), + UNIQUE (room_id, type, state_key) +); + +CREATE INDEX current_state_events_room_id ON current_state_events (room_id); +CREATE INDEX current_state_events_type ON current_state_events (type); +CREATE INDEX current_state_events_state_key ON current_state_events (state_key); + +CREATE TABLE IF NOT EXISTS room_memberships( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + sender TEXT NOT NULL, + room_id TEXT NOT NULL, + membership TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); + +CREATE TABLE IF NOT EXISTS topics( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + topic TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX topics_room_id ON topics(room_id); + +CREATE TABLE IF NOT EXISTS room_names( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + name TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX room_names_room_id ON room_names(room_id); + +CREATE TABLE IF NOT EXISTS rooms( + room_id TEXT PRIMARY KEY NOT NULL, + is_public BOOL, + creator TEXT +); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/keys.sql b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql new file mode 100644 index 0000000000..11cdffdbb3 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql @@ -0,0 +1,26 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we used to create a table called server_tls_certificates, but this is no +-- longer used, and is removed in delta 54. + +CREATE TABLE IF NOT EXISTS server_signature_keys( + server_name TEXT, -- Server name. 
+ key_id TEXT, -- Key version. + from_server TEXT, -- Which key server the key was fetched form. + ts_added_ms BIGINT, -- When the key was added. + verify_key bytea, -- NACL verification key. + UNIQUE (server_name, key_id) +); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql new file mode 100644 index 0000000000..8f3759bb2a --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql @@ -0,0 +1,68 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS local_media_repository ( + media_id TEXT, -- The id used to refer to the media. + media_type TEXT, -- The MIME-type of the media. + media_length INTEGER, -- Length of the media in bytes. + created_ts BIGINT, -- When the content was uploaded in ms. + upload_name TEXT, -- The name the media was uploaded with. + user_id TEXT, -- The user who uploaded the file. + UNIQUE (media_id) +); + +CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( + media_id TEXT, -- The id used to refer to the media. + thumbnail_width INTEGER, -- The width of the thumbnail in pixels. + thumbnail_height INTEGER, -- The height of the thumbnail in pixels. + thumbnail_type TEXT, -- The MIME-type of the thumbnail. + thumbnail_method TEXT, -- The method used to make the thumbnail. + thumbnail_length INTEGER, -- The length of the thumbnail in bytes. + UNIQUE ( + media_id, thumbnail_width, thumbnail_height, thumbnail_type + ) +); + +CREATE INDEX local_media_repository_thumbnails_media_id + ON local_media_repository_thumbnails (media_id); + +CREATE TABLE IF NOT EXISTS remote_media_cache ( + media_origin TEXT, -- The remote HS the media came from. + media_id TEXT, -- The id used to refer to the media on that server. + media_type TEXT, -- The MIME-type of the media. + created_ts BIGINT, -- When the content was uploaded in ms. + upload_name TEXT, -- The name the media was uploaded with. + media_length INTEGER, -- Length of the media in bytes. + filesystem_id TEXT, -- The name used to store the media on disk. + UNIQUE (media_origin, media_id) +); + +CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( + media_origin TEXT, -- The remote HS the media came from. + media_id TEXT, -- The id used to refer to the media. + thumbnail_width INTEGER, -- The width of the thumbnail in pixels. + thumbnail_height INTEGER, -- The height of the thumbnail in pixels. + thumbnail_method TEXT, -- The method used to make the thumbnail + thumbnail_type TEXT, -- The MIME-type of the thumbnail. + thumbnail_length INTEGER, -- The length of the thumbnail in bytes. + filesystem_id TEXT, -- The name used to store the media on disk. 
+ UNIQUE ( + media_origin, media_id, thumbnail_width, thumbnail_height, + thumbnail_type + ) +); + +CREATE INDEX remote_media_cache_thumbnails_media_id + ON remote_media_cache_thumbnails (media_id); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/presence.sql b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql new file mode 100644 index 0000000000..01d2d8f833 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql @@ -0,0 +1,32 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS presence( + user_id TEXT NOT NULL, + state VARCHAR(20), + status_msg TEXT, + mtime BIGINT, -- miliseconds since last state change + UNIQUE (user_id) +); + +-- For each of /my/ users which possibly-remote users are allowed to see their +-- presence state +CREATE TABLE IF NOT EXISTS presence_allow_inbound( + observed_user_id TEXT NOT NULL, + observer_user_id TEXT NOT NULL, -- a UserID, + UNIQUE (observed_user_id, observer_user_id) +); + +-- We used to create a table called presence_list, but this is no longer used +-- and is removed in delta 54. \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql new file mode 100644 index 0000000000..c04f4747d9 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql @@ -0,0 +1,20 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS profiles( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + UNIQUE(user_id) +); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/push.sql b/synapse/storage/databases/main/schema/full_schemas/16/push.sql new file mode 100644 index 0000000000..e44465cf45 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/push.sql @@ -0,0 +1,74 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + UNIQUE (event_id) +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey bytea NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data bytea, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class SMALLINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX push_rules_user_name on push_rules (user_name); + +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id BIGINT, + filter_json bytea +); + +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); + +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled SMALLINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql new file mode 100644 index 0000000000..318f0d9aa5 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql @@ -0,0 +1,22 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS redactions ( + event_id TEXT NOT NULL, + redacts TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX redactions_event_id ON redactions (event_id); +CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql new file mode 100644 index 0000000000..d47da3b12f --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql @@ -0,0 +1,29 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS room_aliases( + room_alias TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (room_alias) +); + +CREATE INDEX room_aliases_id ON room_aliases(room_id); + +CREATE TABLE IF NOT EXISTS room_alias_servers( + room_alias TEXT NOT NULL, + server TEXT NOT NULL +); + +CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/state.sql b/synapse/storage/databases/main/schema/full_schemas/16/state.sql new file mode 100644 index 0000000000..96391a8f0e --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/state.sql @@ -0,0 +1,40 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX state_groups_id ON state_groups(id); + +CREATE INDEX state_groups_state_id ON state_groups_state(state_group); +CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); +CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql new file mode 100644 index 0000000000..17e67bedac --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql @@ -0,0 +1,44 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +-- Stores what transaction ids we have received and what our response was +CREATE TABLE IF NOT EXISTS received_transactions( + transaction_id TEXT, + origin TEXT, + ts BIGINT, + response_code INTEGER, + response_json bytea, + has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx + UNIQUE (transaction_id, origin) +); + +CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; + +-- For sent transactions only. +CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( + transaction_id INTEGER, + destination TEXT, + pdu_id TEXT, + pdu_origin TEXT, + UNIQUE (transaction_id, destination) +); + +CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); + +-- To track destination health +CREATE TABLE IF NOT EXISTS destinations( + destination TEXT PRIMARY KEY, + retry_last_ts BIGINT, + retry_interval INTEGER +); diff --git a/synapse/storage/databases/main/schema/full_schemas/16/users.sql b/synapse/storage/databases/main/schema/full_schemas/16/users.sql new file mode 100644 index 0000000000..f013aa8b18 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/16/users.sql @@ -0,0 +1,42 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +CREATE TABLE IF NOT EXISTS users( + name TEXT, + password_hash TEXT, + creation_ts BIGINT, + admin SMALLINT DEFAULT 0 NOT NULL, + UNIQUE(name) +); + +CREATE TABLE IF NOT EXISTS access_tokens( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used BIGINT, + UNIQUE(token) +); + +CREATE TABLE IF NOT EXISTS user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT NOT NULL +); + +CREATE INDEX user_ips_user ON user_ips(user_id); +CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip); diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres new file mode 100644 index 0000000000..889a9a0ce4 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres @@ -0,0 +1,1983 @@ + + + + + +CREATE TABLE access_tokens ( + id bigint NOT NULL, + user_id text NOT NULL, + device_id text, + token text NOT NULL, + last_used bigint +); + + + +CREATE TABLE account_data ( + user_id text NOT NULL, + account_data_type text NOT NULL, + stream_id bigint NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE account_data_max_stream_id ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint NOT NULL, + CONSTRAINT private_user_data_max_stream_id_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE account_validity ( + user_id text NOT NULL, + expiration_ts_ms bigint NOT NULL, + email_sent boolean NOT NULL, + renewal_token text +); + + + +CREATE TABLE application_services_state ( + as_id text NOT NULL, + state character varying(5), + last_txn integer +); + + + +CREATE TABLE application_services_txns ( + as_id text NOT NULL, + txn_id integer NOT NULL, + event_ids text NOT NULL +); + + + +CREATE TABLE appservice_room_list ( + appservice_id text NOT NULL, + network_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE appservice_stream_position ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_ordering bigint, + CONSTRAINT appservice_stream_position_lock_check CHECK ((lock = 'X'::bpchar)) +); + + +CREATE TABLE blocked_rooms ( + room_id text NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE cache_invalidation_stream ( + stream_id bigint, + cache_func text, + keys text[], + invalidation_ts bigint +); + + + +CREATE TABLE current_state_delta_stream ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + event_id text, + prev_event_id text +); + + + +CREATE TABLE current_state_events ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL +); + + + +CREATE TABLE deleted_pushers ( + stream_id bigint NOT NULL, + app_id text NOT NULL, + pushkey text NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE destinations ( + destination text NOT NULL, + retry_last_ts bigint, + retry_interval integer +); + + + +CREATE TABLE device_federation_inbox ( + origin text NOT NULL, + message_id text NOT NULL, + received_ts bigint NOT NULL +); + + + +CREATE TABLE device_federation_outbox ( + destination text NOT NULL, + stream_id bigint NOT NULL, + queued_ts bigint NOT NULL, + messages_json text NOT NULL +); + + + +CREATE TABLE device_inbox ( + user_id text NOT NULL, + device_id text NOT NULL, + stream_id bigint NOT NULL, + message_json text NOT NULL +); + + + +CREATE TABLE 
device_lists_outbound_last_success ( + destination text NOT NULL, + user_id text NOT NULL, + stream_id bigint NOT NULL +); + + + +CREATE TABLE device_lists_outbound_pokes ( + destination text NOT NULL, + stream_id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL, + sent boolean NOT NULL, + ts bigint NOT NULL +); + + + +CREATE TABLE device_lists_remote_cache ( + user_id text NOT NULL, + device_id text NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE device_lists_remote_extremeties ( + user_id text NOT NULL, + stream_id text NOT NULL +); + + + +CREATE TABLE device_lists_stream ( + stream_id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL +); + + + +CREATE TABLE device_max_stream_id ( + stream_id bigint NOT NULL +); + + + +CREATE TABLE devices ( + user_id text NOT NULL, + device_id text NOT NULL, + display_name text +); + + + +CREATE TABLE e2e_device_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + ts_added_ms bigint NOT NULL, + key_json text NOT NULL +); + + + +CREATE TABLE e2e_one_time_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + algorithm text NOT NULL, + key_id text NOT NULL, + ts_added_ms bigint NOT NULL, + key_json text NOT NULL +); + + + +CREATE TABLE e2e_room_keys ( + user_id text NOT NULL, + room_id text NOT NULL, + session_id text NOT NULL, + version bigint NOT NULL, + first_message_index integer, + forwarded_count integer, + is_verified boolean, + session_data text NOT NULL +); + + + +CREATE TABLE e2e_room_keys_versions ( + user_id text NOT NULL, + version bigint NOT NULL, + algorithm text NOT NULL, + auth_data text NOT NULL, + deleted smallint DEFAULT 0 NOT NULL +); + + + +CREATE TABLE erased_users ( + user_id text NOT NULL +); + + + +CREATE TABLE event_auth ( + event_id text NOT NULL, + auth_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE event_backward_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE event_edges ( + event_id text NOT NULL, + prev_event_id text NOT NULL, + room_id text NOT NULL, + is_state boolean NOT NULL +); + + + +CREATE TABLE event_forward_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE event_json ( + event_id text NOT NULL, + room_id text NOT NULL, + internal_metadata text NOT NULL, + json text NOT NULL, + format_version integer +); + + + +CREATE TABLE event_push_actions ( + room_id text NOT NULL, + event_id text NOT NULL, + user_id text NOT NULL, + profile_tag character varying(32), + actions text NOT NULL, + topological_ordering bigint, + stream_ordering bigint, + notif smallint, + highlight smallint +); + + + +CREATE TABLE event_push_actions_staging ( + event_id text NOT NULL, + user_id text NOT NULL, + actions text NOT NULL, + notif smallint NOT NULL, + highlight smallint NOT NULL +); + + + +CREATE TABLE event_push_summary ( + user_id text NOT NULL, + room_id text NOT NULL, + notif_count bigint NOT NULL, + stream_ordering bigint NOT NULL +); + + + +CREATE TABLE event_push_summary_stream_ordering ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_ordering bigint NOT NULL, + CONSTRAINT event_push_summary_stream_ordering_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE event_reference_hashes ( + event_id text, + algorithm text, + hash bytea +); + + + +CREATE TABLE event_relations ( + event_id text NOT NULL, + relates_to_id text NOT NULL, + relation_type text NOT NULL, + aggregation_key text +); + + + +CREATE TABLE event_reports ( + id bigint NOT NULL, + 
received_ts bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL, + user_id text NOT NULL, + reason text, + content text +); + + + +CREATE TABLE event_search ( + event_id text, + room_id text, + sender text, + key text, + vector tsvector, + origin_server_ts bigint, + stream_ordering bigint +); + + + +CREATE TABLE event_to_state_groups ( + event_id text NOT NULL, + state_group bigint NOT NULL +); + + + +CREATE TABLE events ( + stream_ordering integer NOT NULL, + topological_ordering bigint NOT NULL, + event_id text NOT NULL, + type text NOT NULL, + room_id text NOT NULL, + content text, + unrecognized_keys text, + processed boolean NOT NULL, + outlier boolean NOT NULL, + depth bigint DEFAULT 0 NOT NULL, + origin_server_ts bigint, + received_ts bigint, + sender text, + contains_url boolean +); + + + +CREATE TABLE ex_outlier_stream ( + event_stream_ordering bigint NOT NULL, + event_id text NOT NULL, + state_group bigint NOT NULL +); + + + +CREATE TABLE federation_stream_position ( + type text NOT NULL, + stream_id integer NOT NULL +); + + + +CREATE TABLE group_attestations_remote ( + group_id text NOT NULL, + user_id text NOT NULL, + valid_until_ms bigint NOT NULL, + attestation_json text NOT NULL +); + + + +CREATE TABLE group_attestations_renewals ( + group_id text NOT NULL, + user_id text NOT NULL, + valid_until_ms bigint NOT NULL +); + + + +CREATE TABLE group_invites ( + group_id text NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE group_roles ( + group_id text NOT NULL, + role_id text NOT NULL, + profile text NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_room_categories ( + group_id text NOT NULL, + category_id text NOT NULL, + profile text NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_rooms ( + group_id text NOT NULL, + room_id text NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_summary_roles ( + group_id text NOT NULL, + role_id text NOT NULL, + role_order bigint NOT NULL, + CONSTRAINT group_summary_roles_role_order_check CHECK ((role_order > 0)) +); + + + +CREATE TABLE group_summary_room_categories ( + group_id text NOT NULL, + category_id text NOT NULL, + cat_order bigint NOT NULL, + CONSTRAINT group_summary_room_categories_cat_order_check CHECK ((cat_order > 0)) +); + + + +CREATE TABLE group_summary_rooms ( + group_id text NOT NULL, + room_id text NOT NULL, + category_id text NOT NULL, + room_order bigint NOT NULL, + is_public boolean NOT NULL, + CONSTRAINT group_summary_rooms_room_order_check CHECK ((room_order > 0)) +); + + + +CREATE TABLE group_summary_users ( + group_id text NOT NULL, + user_id text NOT NULL, + role_id text NOT NULL, + user_order bigint NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_users ( + group_id text NOT NULL, + user_id text NOT NULL, + is_admin boolean NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE groups ( + group_id text NOT NULL, + name text, + avatar_url text, + short_description text, + long_description text, + is_public boolean NOT NULL, + join_policy text DEFAULT 'invite'::text NOT NULL +); + + + +CREATE TABLE guest_access ( + event_id text NOT NULL, + room_id text NOT NULL, + guest_access text NOT NULL +); + + + +CREATE TABLE history_visibility ( + event_id text NOT NULL, + room_id text NOT NULL, + history_visibility text NOT NULL +); + + + +CREATE TABLE local_group_membership ( + group_id text NOT NULL, + user_id text NOT NULL, + is_admin boolean NOT NULL, + membership text NOT NULL, + is_publicised boolean NOT NULL, + 
content text NOT NULL +); + + + +CREATE TABLE local_group_updates ( + stream_id bigint NOT NULL, + group_id text NOT NULL, + user_id text NOT NULL, + type text NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE local_invites ( + stream_id bigint NOT NULL, + inviter text NOT NULL, + invitee text NOT NULL, + event_id text NOT NULL, + room_id text NOT NULL, + locally_rejected text, + replaced_by text +); + + + +CREATE TABLE local_media_repository ( + media_id text, + media_type text, + media_length integer, + created_ts bigint, + upload_name text, + user_id text, + quarantined_by text, + url_cache text, + last_access_ts bigint +); + + + +CREATE TABLE local_media_repository_thumbnails ( + media_id text, + thumbnail_width integer, + thumbnail_height integer, + thumbnail_type text, + thumbnail_method text, + thumbnail_length integer +); + + + +CREATE TABLE local_media_repository_url_cache ( + url text, + response_code integer, + etag text, + expires_ts bigint, + og text, + media_id text, + download_ts bigint +); + + + +CREATE TABLE monthly_active_users ( + user_id text NOT NULL, + "timestamp" bigint NOT NULL +); + + + +CREATE TABLE open_id_tokens ( + token text NOT NULL, + ts_valid_until_ms bigint NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE presence ( + user_id text NOT NULL, + state character varying(20), + status_msg text, + mtime bigint +); + + + +CREATE TABLE presence_allow_inbound ( + observed_user_id text NOT NULL, + observer_user_id text NOT NULL +); + + + +CREATE TABLE presence_stream ( + stream_id bigint, + user_id text, + state text, + last_active_ts bigint, + last_federation_update_ts bigint, + last_user_sync_ts bigint, + status_msg text, + currently_active boolean +); + + + +CREATE TABLE profiles ( + user_id text NOT NULL, + displayname text, + avatar_url text +); + + + +CREATE TABLE public_room_list_stream ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + visibility boolean NOT NULL, + appservice_id text, + network_id text +); + + + +CREATE TABLE push_rules ( + id bigint NOT NULL, + user_name text NOT NULL, + rule_id text NOT NULL, + priority_class smallint NOT NULL, + priority integer DEFAULT 0 NOT NULL, + conditions text NOT NULL, + actions text NOT NULL +); + + + +CREATE TABLE push_rules_enable ( + id bigint NOT NULL, + user_name text NOT NULL, + rule_id text NOT NULL, + enabled smallint +); + + + +CREATE TABLE push_rules_stream ( + stream_id bigint NOT NULL, + event_stream_ordering bigint NOT NULL, + user_id text NOT NULL, + rule_id text NOT NULL, + op text NOT NULL, + priority_class smallint, + priority integer, + conditions text, + actions text +); + + + +CREATE TABLE pusher_throttle ( + pusher bigint NOT NULL, + room_id text NOT NULL, + last_sent_ts bigint, + throttle_ms bigint +); + + + +CREATE TABLE pushers ( + id bigint NOT NULL, + user_name text NOT NULL, + access_token bigint, + profile_tag text NOT NULL, + kind text NOT NULL, + app_id text NOT NULL, + app_display_name text NOT NULL, + device_display_name text NOT NULL, + pushkey text NOT NULL, + ts bigint NOT NULL, + lang text, + data text, + last_stream_ordering integer, + last_success bigint, + failing_since bigint +); + + + +CREATE TABLE ratelimit_override ( + user_id text NOT NULL, + messages_per_second bigint, + burst_count bigint +); + + + +CREATE TABLE receipts_graph ( + room_id text NOT NULL, + receipt_type text NOT NULL, + user_id text NOT NULL, + event_ids text NOT NULL, + data text NOT NULL +); + + + +CREATE TABLE receipts_linearized ( + stream_id bigint NOT NULL, + room_id text 
NOT NULL, + receipt_type text NOT NULL, + user_id text NOT NULL, + event_id text NOT NULL, + data text NOT NULL +); + + + +CREATE TABLE received_transactions ( + transaction_id text, + origin text, + ts bigint, + response_code integer, + response_json bytea, + has_been_referenced smallint DEFAULT 0 +); + + + +CREATE TABLE redactions ( + event_id text NOT NULL, + redacts text NOT NULL +); + + + +CREATE TABLE rejections ( + event_id text NOT NULL, + reason text NOT NULL, + last_check text NOT NULL +); + + + +CREATE TABLE remote_media_cache ( + media_origin text, + media_id text, + media_type text, + created_ts bigint, + upload_name text, + media_length integer, + filesystem_id text, + last_access_ts bigint, + quarantined_by text +); + + + +CREATE TABLE remote_media_cache_thumbnails ( + media_origin text, + media_id text, + thumbnail_width integer, + thumbnail_height integer, + thumbnail_method text, + thumbnail_type text, + thumbnail_length integer, + filesystem_id text +); + + + +CREATE TABLE remote_profile_cache ( + user_id text NOT NULL, + displayname text, + avatar_url text, + last_check bigint NOT NULL +); + + + +CREATE TABLE room_account_data ( + user_id text NOT NULL, + room_id text NOT NULL, + account_data_type text NOT NULL, + stream_id bigint NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE room_alias_servers ( + room_alias text NOT NULL, + server text NOT NULL +); + + + +CREATE TABLE room_aliases ( + room_alias text NOT NULL, + room_id text NOT NULL, + creator text +); + + + +CREATE TABLE room_depth ( + room_id text NOT NULL, + min_depth integer NOT NULL +); + + + +CREATE TABLE room_memberships ( + event_id text NOT NULL, + user_id text NOT NULL, + sender text NOT NULL, + room_id text NOT NULL, + membership text NOT NULL, + forgotten integer DEFAULT 0, + display_name text, + avatar_url text +); + + + +CREATE TABLE room_names ( + event_id text NOT NULL, + room_id text NOT NULL, + name text NOT NULL +); + + + +CREATE TABLE room_state ( + room_id text NOT NULL, + join_rules text, + history_visibility text, + encryption text, + name text, + topic text, + avatar text, + canonical_alias text +); + + + +CREATE TABLE room_stats ( + room_id text NOT NULL, + ts bigint NOT NULL, + bucket_size integer NOT NULL, + current_state_events integer NOT NULL, + joined_members integer NOT NULL, + invited_members integer NOT NULL, + left_members integer NOT NULL, + banned_members integer NOT NULL, + state_events integer NOT NULL +); + + + +CREATE TABLE room_stats_earliest_token ( + room_id text NOT NULL, + token bigint NOT NULL +); + + + +CREATE TABLE room_tags ( + user_id text NOT NULL, + room_id text NOT NULL, + tag text NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE room_tags_revisions ( + user_id text NOT NULL, + room_id text NOT NULL, + stream_id bigint NOT NULL +); + + + +CREATE TABLE rooms ( + room_id text NOT NULL, + is_public boolean, + creator text +); + + + +CREATE TABLE server_keys_json ( + server_name text NOT NULL, + key_id text NOT NULL, + from_server text NOT NULL, + ts_added_ms bigint NOT NULL, + ts_valid_until_ms bigint NOT NULL, + key_json bytea NOT NULL +); + + + +CREATE TABLE server_signature_keys ( + server_name text, + key_id text, + from_server text, + ts_added_ms bigint, + verify_key bytea, + ts_valid_until_ms bigint +); + + + +CREATE TABLE state_events ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + prev_state text +); + + + +CREATE TABLE stats_stream_pos ( + lock character(1) DEFAULT 'X'::bpchar NOT 
NULL, + stream_id bigint, + CONSTRAINT stats_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL +); + + + +CREATE TABLE threepid_guest_access_tokens ( + medium text, + address text, + guest_access_token text, + first_inviter text +); + + + +CREATE TABLE topics ( + event_id text NOT NULL, + room_id text NOT NULL, + topic text NOT NULL +); + + + +CREATE TABLE user_daily_visits ( + user_id text NOT NULL, + device_id text, + "timestamp" bigint NOT NULL +); + + + +CREATE TABLE user_directory ( + user_id text NOT NULL, + room_id text, + display_name text, + avatar_url text +); + + + +CREATE TABLE user_directory_search ( + user_id text NOT NULL, + vector tsvector +); + + + +CREATE TABLE user_directory_stream_pos ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint, + CONSTRAINT user_directory_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE user_filters ( + user_id text, + filter_id bigint, + filter_json bytea +); + + + +CREATE TABLE user_ips ( + user_id text NOT NULL, + access_token text NOT NULL, + device_id text, + ip text NOT NULL, + user_agent text NOT NULL, + last_seen bigint NOT NULL +); + + + +CREATE TABLE user_stats ( + user_id text NOT NULL, + ts bigint NOT NULL, + bucket_size integer NOT NULL, + public_rooms integer NOT NULL, + private_rooms integer NOT NULL +); + + + +CREATE TABLE user_threepid_id_server ( + user_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + id_server text NOT NULL +); + + + +CREATE TABLE user_threepids ( + user_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + validated_at bigint NOT NULL, + added_at bigint NOT NULL +); + + + +CREATE TABLE users ( + name text, + password_hash text, + creation_ts bigint, + admin smallint DEFAULT 0 NOT NULL, + upgrade_ts bigint, + is_guest smallint DEFAULT 0 NOT NULL, + appservice_id text, + consent_version text, + consent_server_notice_sent text, + user_type text +); + + + +CREATE TABLE users_in_public_rooms ( + user_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE users_pending_deactivation ( + user_id text NOT NULL +); + + + +CREATE TABLE users_who_share_private_rooms ( + user_id text NOT NULL, + other_user_id text NOT NULL, + room_id text NOT NULL +); + + + +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_token_key UNIQUE (token); + + + +ALTER TABLE ONLY account_data + ADD CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type); + + + +ALTER TABLE ONLY account_validity + ADD CONSTRAINT account_validity_pkey PRIMARY KEY (user_id); + + + +ALTER TABLE ONLY application_services_state + ADD CONSTRAINT application_services_state_pkey PRIMARY KEY (as_id); + + + +ALTER TABLE ONLY application_services_txns + ADD CONSTRAINT application_services_txns_as_id_txn_id_key UNIQUE (as_id, txn_id); + + + +ALTER TABLE ONLY appservice_stream_position + ADD CONSTRAINT appservice_stream_position_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY current_state_events + ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY current_state_events + ADD CONSTRAINT current_state_events_room_id_type_state_key_key UNIQUE (room_id, type, state_key); + + + +ALTER TABLE ONLY destinations + ADD CONSTRAINT destinations_pkey PRIMARY KEY (destination); + + + +ALTER TABLE ONLY 
devices + ADD CONSTRAINT device_uniqueness UNIQUE (user_id, device_id); + + + +ALTER TABLE ONLY e2e_device_keys_json + ADD CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id); + + + +ALTER TABLE ONLY e2e_one_time_keys_json + ADD CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id); + + + +ALTER TABLE ONLY event_backward_extremities + ADD CONSTRAINT event_backward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); + + + +ALTER TABLE ONLY event_edges + ADD CONSTRAINT event_edges_event_id_prev_event_id_room_id_is_state_key UNIQUE (event_id, prev_event_id, room_id, is_state); + + + +ALTER TABLE ONLY event_forward_extremities + ADD CONSTRAINT event_forward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); + + + +ALTER TABLE ONLY event_push_actions + ADD CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag); + + + +ALTER TABLE ONLY event_json + ADD CONSTRAINT event_json_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY event_push_summary_stream_ordering + ADD CONSTRAINT event_push_summary_stream_ordering_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY event_reference_hashes + ADD CONSTRAINT event_reference_hashes_event_id_algorithm_key UNIQUE (event_id, algorithm); + + + +ALTER TABLE ONLY event_reports + ADD CONSTRAINT event_reports_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY event_to_state_groups + ADD CONSTRAINT event_to_state_groups_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY events + ADD CONSTRAINT events_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY events + ADD CONSTRAINT events_pkey PRIMARY KEY (stream_ordering); + + + +ALTER TABLE ONLY ex_outlier_stream + ADD CONSTRAINT ex_outlier_stream_pkey PRIMARY KEY (event_stream_ordering); + + + +ALTER TABLE ONLY group_roles + ADD CONSTRAINT group_roles_group_id_role_id_key UNIQUE (group_id, role_id); + + + +ALTER TABLE ONLY group_room_categories + ADD CONSTRAINT group_room_categories_group_id_category_id_key UNIQUE (group_id, category_id); + + + +ALTER TABLE ONLY group_summary_roles + ADD CONSTRAINT group_summary_roles_group_id_role_id_role_order_key UNIQUE (group_id, role_id, role_order); + + + +ALTER TABLE ONLY group_summary_room_categories + ADD CONSTRAINT group_summary_room_categories_group_id_category_id_cat_orde_key UNIQUE (group_id, category_id, cat_order); + + + +ALTER TABLE ONLY group_summary_rooms + ADD CONSTRAINT group_summary_rooms_group_id_category_id_room_id_room_order_key UNIQUE (group_id, category_id, room_id, room_order); + + + +ALTER TABLE ONLY guest_access + ADD CONSTRAINT guest_access_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY history_visibility + ADD CONSTRAINT history_visibility_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY local_media_repository + ADD CONSTRAINT local_media_repository_media_id_key UNIQUE (media_id); + + + +ALTER TABLE ONLY local_media_repository_thumbnails + ADD CONSTRAINT local_media_repository_thumbn_media_id_thumbnail_width_thum_key UNIQUE (media_id, thumbnail_width, thumbnail_height, thumbnail_type); + + + +ALTER TABLE ONLY user_threepids + ADD CONSTRAINT medium_address UNIQUE (medium, address); + + + +ALTER TABLE ONLY open_id_tokens + ADD CONSTRAINT open_id_tokens_pkey PRIMARY KEY (token); + + + +ALTER TABLE ONLY presence_allow_inbound + ADD CONSTRAINT presence_allow_inbound_observed_user_id_observer_user_id_key UNIQUE (observed_user_id, observer_user_id); + + + +ALTER TABLE ONLY presence + ADD CONSTRAINT presence_user_id_key 
UNIQUE (user_id); + + + +ALTER TABLE ONLY account_data_max_stream_id + ADD CONSTRAINT private_user_data_max_stream_id_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY profiles + ADD CONSTRAINT profiles_user_id_key UNIQUE (user_id); + + + +ALTER TABLE ONLY push_rules_enable + ADD CONSTRAINT push_rules_enable_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY push_rules_enable + ADD CONSTRAINT push_rules_enable_user_name_rule_id_key UNIQUE (user_name, rule_id); + + + +ALTER TABLE ONLY push_rules + ADD CONSTRAINT push_rules_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY push_rules + ADD CONSTRAINT push_rules_user_name_rule_id_key UNIQUE (user_name, rule_id); + + + +ALTER TABLE ONLY pusher_throttle + ADD CONSTRAINT pusher_throttle_pkey PRIMARY KEY (pusher, room_id); + + + +ALTER TABLE ONLY pushers + ADD CONSTRAINT pushers2_app_id_pushkey_user_name_key UNIQUE (app_id, pushkey, user_name); + + + +ALTER TABLE ONLY pushers + ADD CONSTRAINT pushers2_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id); + + + +ALTER TABLE ONLY receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id); + + + +ALTER TABLE ONLY received_transactions + ADD CONSTRAINT received_transactions_transaction_id_origin_key UNIQUE (transaction_id, origin); + + + +ALTER TABLE ONLY redactions + ADD CONSTRAINT redactions_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY rejections + ADD CONSTRAINT rejections_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY remote_media_cache + ADD CONSTRAINT remote_media_cache_media_origin_media_id_key UNIQUE (media_origin, media_id); + + + +ALTER TABLE ONLY remote_media_cache_thumbnails + ADD CONSTRAINT remote_media_cache_thumbnails_media_origin_media_id_thumbna_key UNIQUE (media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type); + + + +ALTER TABLE ONLY room_account_data + ADD CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type); + + + +ALTER TABLE ONLY room_aliases + ADD CONSTRAINT room_aliases_room_alias_key UNIQUE (room_alias); + + + +ALTER TABLE ONLY room_depth + ADD CONSTRAINT room_depth_room_id_key UNIQUE (room_id); + + + +ALTER TABLE ONLY room_memberships + ADD CONSTRAINT room_memberships_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY room_names + ADD CONSTRAINT room_names_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY room_tags_revisions + ADD CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id); + + + +ALTER TABLE ONLY room_tags + ADD CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag); + + + +ALTER TABLE ONLY rooms + ADD CONSTRAINT rooms_pkey PRIMARY KEY (room_id); + + + +ALTER TABLE ONLY server_keys_json + ADD CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server); + + + +ALTER TABLE ONLY server_signature_keys + ADD CONSTRAINT server_signature_keys_server_name_key_id_key UNIQUE (server_name, key_id); + + + +ALTER TABLE ONLY state_events + ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); + + +ALTER TABLE ONLY stats_stream_pos + ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY topics + ADD CONSTRAINT topics_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY user_directory_stream_pos + ADD CONSTRAINT user_directory_stream_pos_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY users + ADD CONSTRAINT users_name_key UNIQUE (name); + + + +CREATE INDEX access_tokens_device_id ON 
access_tokens USING btree (user_id, device_id); + + + +CREATE INDEX account_data_stream_id ON account_data USING btree (user_id, stream_id); + + + +CREATE INDEX application_services_txns_id ON application_services_txns USING btree (as_id); + + + +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list USING btree (appservice_id, network_id, room_id); + + + +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms USING btree (room_id); + + + +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream USING btree (stream_id); + + + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream USING btree (stream_id); + + + +CREATE INDEX current_state_events_member_index ON current_state_events USING btree (state_key) WHERE (type = 'm.room.member'::text); + + + +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers USING btree (stream_id); + + + +CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox USING btree (origin, message_id); + + + +CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox USING btree (destination, stream_id); + + + +CREATE INDEX device_federation_outbox_id ON device_federation_outbox USING btree (stream_id); + + + +CREATE INDEX device_inbox_stream_id_user_id ON device_inbox USING btree (stream_id, user_id); + + + +CREATE INDEX device_inbox_user_stream_id ON device_inbox USING btree (user_id, device_id, stream_id); + + + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success USING btree (destination, user_id, stream_id); + + + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes USING btree (destination, stream_id); + + + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes USING btree (stream_id); + + + +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes USING btree (destination, user_id); + + + +CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache USING btree (user_id, device_id); + + + +CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties USING btree (user_id); + + + +CREATE INDEX device_lists_stream_id ON device_lists_stream USING btree (stream_id, user_id); + + + +CREATE INDEX device_lists_stream_user_id ON device_lists_stream USING btree (user_id, device_id); + + + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys USING btree (user_id, room_id, session_id); + + + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions USING btree (user_id, version); + + + +CREATE UNIQUE INDEX erased_users_user ON erased_users USING btree (user_id); + + + +CREATE INDEX ev_b_extrem_id ON event_backward_extremities USING btree (event_id); + + + +CREATE INDEX ev_b_extrem_room ON event_backward_extremities USING btree (room_id); + + + +CREATE INDEX ev_edges_id ON event_edges USING btree (event_id); + + + +CREATE INDEX ev_edges_prev_id ON event_edges USING btree (prev_event_id); + + + +CREATE INDEX ev_extrem_id ON event_forward_extremities USING btree (event_id); + + + +CREATE INDEX ev_extrem_room ON event_forward_extremities USING btree (room_id); + + + +CREATE INDEX evauth_edges_id ON event_auth USING btree (event_id); + + + +CREATE INDEX event_contains_url_index ON events USING btree (room_id, topological_ordering, stream_ordering) WHERE ((contains_url = true) AND (outlier = false)); + + + +CREATE INDEX event_json_room_id ON event_json USING btree (room_id); + + + +CREATE INDEX 
event_push_actions_highlights_index ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering) WHERE (highlight = 1); + + + +CREATE INDEX event_push_actions_rm_tokens ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering); + + + +CREATE INDEX event_push_actions_room_id_user_id ON event_push_actions USING btree (room_id, user_id); + + + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging USING btree (event_id); + + + +CREATE INDEX event_push_actions_stream_ordering ON event_push_actions USING btree (stream_ordering, user_id); + + + +CREATE INDEX event_push_actions_u_highlight ON event_push_actions USING btree (user_id, stream_ordering); + + + +CREATE INDEX event_push_summary_user_rm ON event_push_summary USING btree (user_id, room_id); + + + +CREATE INDEX event_reference_hashes_id ON event_reference_hashes USING btree (event_id); + + + +CREATE UNIQUE INDEX event_relations_id ON event_relations USING btree (event_id); + + + +CREATE INDEX event_relations_relates ON event_relations USING btree (relates_to_id, relation_type, aggregation_key); + + + +CREATE INDEX event_search_ev_ridx ON event_search USING btree (room_id); + + + +CREATE UNIQUE INDEX event_search_event_id_idx ON event_search USING btree (event_id); + + + +CREATE INDEX event_search_fts_idx ON event_search USING gin (vector); + + + +CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups USING btree (state_group); + + + +CREATE INDEX events_order_room ON events USING btree (room_id, topological_ordering, stream_ordering); + + + +CREATE INDEX events_room_stream ON events USING btree (room_id, stream_ordering); + + + +CREATE INDEX events_ts ON events USING btree (origin_server_ts, stream_ordering); + + + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote USING btree (group_id, user_id); + + + +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote USING btree (user_id); + + + +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote USING btree (valid_until_ms); + + + +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals USING btree (group_id, user_id); + + + +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals USING btree (user_id); + + + +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals USING btree (valid_until_ms); + + + +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites USING btree (group_id, user_id); + + + +CREATE INDEX group_invites_u_idx ON group_invites USING btree (user_id); + + + +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms USING btree (group_id, room_id); + + + +CREATE INDEX group_rooms_r_idx ON group_rooms USING btree (room_id); + + + +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms USING btree (group_id, room_id, category_id); + + + +CREATE INDEX group_summary_users_g_idx ON group_summary_users USING btree (group_id); + + + +CREATE UNIQUE INDEX group_users_g_idx ON group_users USING btree (group_id, user_id); + + + +CREATE INDEX group_users_u_idx ON group_users USING btree (user_id); + + + +CREATE UNIQUE INDEX groups_idx ON groups USING btree (group_id); + + + +CREATE INDEX local_group_membership_g_idx ON local_group_membership USING btree (group_id); + + + +CREATE INDEX local_group_membership_u_idx ON local_group_membership USING btree (user_id, group_id); + + + +CREATE INDEX local_invites_for_user_idx ON local_invites USING btree (invitee, 
locally_rejected, replaced_by, room_id); + + + +CREATE INDEX local_invites_id ON local_invites USING btree (stream_id); + + + +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails USING btree (media_id); + + + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache USING btree (url, download_ts); + + + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache USING btree (expires_ts); + + + +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache USING btree (media_id); + + + +CREATE INDEX local_media_repository_url_idx ON local_media_repository USING btree (created_ts) WHERE (url_cache IS NOT NULL); + + + +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users USING btree ("timestamp"); + + + +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users USING btree (user_id); + + + +CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens USING btree (ts_valid_until_ms); + + + +CREATE INDEX presence_stream_id ON presence_stream USING btree (stream_id, user_id); + + + +CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); + + + +CREATE INDEX public_room_index ON rooms USING btree (is_public); + + + +CREATE INDEX public_room_list_stream_idx ON public_room_list_stream USING btree (stream_id); + + + +CREATE INDEX public_room_list_stream_rm_idx ON public_room_list_stream USING btree (room_id, stream_id); + + + +CREATE INDEX push_rules_enable_user_name ON push_rules_enable USING btree (user_name); + + + +CREATE INDEX push_rules_stream_id ON push_rules_stream USING btree (stream_id); + + + +CREATE INDEX push_rules_stream_user_stream_id ON push_rules_stream USING btree (user_id, stream_id); + + + +CREATE INDEX push_rules_user_name ON push_rules USING btree (user_name); + + + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override USING btree (user_id); + + + +CREATE INDEX receipts_linearized_id ON receipts_linearized USING btree (stream_id); + + + +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized USING btree (room_id, stream_id); + + + +CREATE INDEX receipts_linearized_user ON receipts_linearized USING btree (user_id); + + + +CREATE INDEX received_transactions_ts ON received_transactions USING btree (ts); + + + +CREATE INDEX redactions_redacts ON redactions USING btree (redacts); + + + +CREATE INDEX remote_profile_cache_time ON remote_profile_cache USING btree (last_check); + + + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache USING btree (user_id); + + + +CREATE INDEX room_account_data_stream_id ON room_account_data USING btree (user_id, stream_id); + + + +CREATE INDEX room_alias_servers_alias ON room_alias_servers USING btree (room_alias); + + + +CREATE INDEX room_aliases_id ON room_aliases USING btree (room_id); + + + +CREATE INDEX room_depth_room ON room_depth USING btree (room_id); + + + +CREATE INDEX room_memberships_room_id ON room_memberships USING btree (room_id); + + + +CREATE INDEX room_memberships_user_id ON room_memberships USING btree (user_id); + + + +CREATE INDEX room_names_room_id ON room_names USING btree (room_id); + + + +CREATE UNIQUE INDEX room_state_room ON room_state USING btree (room_id); + + + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token USING btree (room_id); + + + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); + + + +CREATE INDEX 
stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); + + + +CREATE INDEX stream_ordering_to_exterm_rm_idx ON stream_ordering_to_exterm USING btree (room_id, stream_ordering); + + + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens USING btree (medium, address); + + + +CREATE INDEX topics_room_id ON topics USING btree (room_id); + + + +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits USING btree ("timestamp"); + + + +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits USING btree (user_id, "timestamp"); + + + +CREATE INDEX user_directory_room_idx ON user_directory USING btree (room_id); + + + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin (vector); + + + +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search USING btree (user_id); + + + +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory USING btree (user_id); + + + +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters USING btree (user_id, filter_id); + + + +CREATE INDEX user_ips_device_id ON user_ips USING btree (user_id, device_id, last_seen); + + + +CREATE INDEX user_ips_last_seen ON user_ips USING btree (user_id, last_seen); + + + +CREATE INDEX user_ips_last_seen_only ON user_ips USING btree (last_seen); + + + +CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips USING btree (user_id, access_token, ip); + + + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats USING btree (user_id, ts); + + + +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server USING btree (user_id, medium, address, id_server); + + + +CREATE INDEX user_threepids_medium_address ON user_threepids USING btree (medium, address); + + + +CREATE INDEX user_threepids_user_id ON user_threepids USING btree (user_id); + + + +CREATE INDEX users_creation_ts ON users USING btree (creation_ts); + + + +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms USING btree (user_id, room_id); + + + +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms USING btree (other_user_id); + + + +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms USING btree (room_id); + + + +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite new file mode 100644 index 0000000000..a0411ede7e --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite @@ -0,0 +1,253 @@ +CREATE TABLE application_services_state( as_id TEXT PRIMARY KEY, state VARCHAR(5), last_txn INTEGER ); +CREATE TABLE application_services_txns( as_id TEXT NOT NULL, txn_id INTEGER NOT NULL, event_ids TEXT NOT NULL, UNIQUE(as_id, txn_id) ); +CREATE INDEX application_services_txns_id ON application_services_txns ( as_id ); +CREATE TABLE presence( user_id TEXT NOT NULL, state VARCHAR(20), status_msg TEXT, mtime BIGINT, UNIQUE (user_id) ); +CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_user_id TEXT NOT NULL, UNIQUE (observed_user_id, observer_user_id) ); +CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, 
consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); +CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); +CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); +CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); +CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); +CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); +CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); +CREATE INDEX events_order_room ON events ( room_id, topological_ordering, stream_ordering ); +CREATE TABLE event_json( event_id TEXT NOT NULL, room_id TEXT NOT NULL, internal_metadata TEXT NOT NULL, json TEXT NOT NULL, format_version INTEGER, UNIQUE (event_id) ); +CREATE INDEX event_json_room_id ON event_json(room_id); +CREATE TABLE state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, prev_state TEXT, UNIQUE (event_id) ); +CREATE TABLE current_state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, UNIQUE (event_id), UNIQUE (room_id, type, state_key) ); +CREATE TABLE room_memberships( event_id TEXT NOT NULL, user_id TEXT NOT NULL, sender TEXT NOT NULL, room_id TEXT NOT NULL, membership TEXT NOT NULL, forgotten INTEGER DEFAULT 0, display_name TEXT, avatar_url TEXT, UNIQUE (event_id) ); +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); +CREATE TABLE topics( event_id TEXT NOT NULL, room_id TEXT NOT NULL, topic TEXT NOT NULL, UNIQUE (event_id) ); +CREATE INDEX topics_room_id ON topics(room_id); +CREATE TABLE room_names( event_id TEXT NOT NULL, room_id TEXT NOT NULL, name TEXT NOT NULL, UNIQUE (event_id) ); +CREATE INDEX room_names_room_id ON room_names(room_id); +CREATE TABLE rooms( room_id TEXT PRIMARY KEY NOT NULL, is_public BOOL, creator TEXT ); +CREATE TABLE server_signature_keys( server_name TEXT, key_id TEXT, from_server TEXT, ts_added_ms BIGINT, verify_key bytea, ts_valid_until_ms BIGINT, UNIQUE (server_name, key_id) ); +CREATE TABLE rejections( event_id TEXT NOT NULL, reason TEXT NOT NULL, last_check TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE push_rules ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, priority_class SMALLINT NOT NULL, priority INTEGER NOT NULL DEFAULT 0, conditions TEXT NOT NULL, actions TEXT NOT NULL, UNIQUE(user_name, rule_id) ); +CREATE INDEX push_rules_user_name on push_rules (user_name); +CREATE TABLE user_filters( user_id TEXT, filter_id BIGINT, filter_json bytea ); +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( user_id, filter_id ); +CREATE TABLE push_rules_enable ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, enabled SMALLINT, 
UNIQUE(user_name, rule_id) ); +CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); +CREATE TABLE event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); +CREATE TABLE event_backward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); +CREATE TABLE event_edges( event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL, room_id TEXT NOT NULL, is_state BOOL NOT NULL, UNIQUE (event_id, prev_event_id, room_id, is_state) ); +CREATE INDEX ev_edges_id ON event_edges(event_id); +CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); +CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); +CREATE INDEX room_depth_room ON room_depth(room_id); +CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); +CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); +CREATE TABLE remote_media_cache ( media_origin TEXT, media_id TEXT, media_type TEXT, created_ts BIGINT, upload_name TEXT, media_length INTEGER, filesystem_id TEXT, last_access_ts BIGINT, quarantined_by TEXT, UNIQUE (media_origin, media_id) ); +CREATE TABLE remote_media_cache_thumbnails ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); +CREATE TABLE redactions ( event_id TEXT NOT NULL, redacts TEXT NOT NULL, UNIQUE (event_id) ); +CREATE INDEX redactions_redacts ON redactions (redacts); +CREATE TABLE room_aliases( room_alias TEXT NOT NULL, room_id TEXT NOT NULL, creator TEXT, UNIQUE (room_alias) ); +CREATE INDEX room_aliases_id ON room_aliases(room_id); +CREATE TABLE room_alias_servers( room_alias TEXT NOT NULL, server TEXT NOT NULL ); +CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); +CREATE TABLE event_reference_hashes ( event_id TEXT, algorithm TEXT, hash bytea, UNIQUE (event_id, algorithm) ); +CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); +CREATE TABLE IF NOT EXISTS "server_keys_json" ( server_name TEXT NOT NULL, key_id TEXT NOT NULL, from_server TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, ts_valid_until_ms BIGINT NOT NULL, key_json bytea NOT NULL, CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) ); +CREATE TABLE e2e_device_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) ); +CREATE TABLE e2e_one_time_keys_json ( user_id TEXT 
NOT NULL, device_id TEXT NOT NULL, algorithm TEXT NOT NULL, key_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); +CREATE TABLE receipts_graph( room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_ids TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) ); +CREATE TABLE receipts_linearized ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_id TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) ); +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); +CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) +/* event_search(event_id,room_id,sender,"key",value) */; +CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); +CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); +CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); +CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); +CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB); +CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); +CREATE TABLE room_tags_revisions ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, stream_id BIGINT NOT NULL, CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) ); +CREATE TABLE IF NOT EXISTS "account_data_max_stream_id"( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT NOT NULL, CHECK (Lock='X') ); +CREATE TABLE account_data( user_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) ); +CREATE TABLE room_account_data( user_id TEXT NOT NULL, room_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) ); +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); +CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering); +CREATE TABLE event_push_actions( room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, profile_tag VARCHAR(32), actions TEXT NOT NULL, topological_ordering BIGINT, 
stream_ordering BIGINT, notif SMALLINT, highlight SMALLINT, CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) ); +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); +CREATE INDEX events_room_stream on events(room_id, stream_ordering); +CREATE INDEX public_room_index on rooms(is_public); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); +CREATE TABLE presence_stream( stream_id BIGINT, user_id TEXT, state TEXT, last_active_ts BIGINT, last_federation_update_ts BIGINT, last_user_sync_ts BIGINT, status_msg TEXT, currently_active BOOLEAN ); +CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); +CREATE INDEX presence_stream_user_id ON presence_stream(user_id); +CREATE TABLE push_rules_stream( stream_id BIGINT NOT NULL, event_stream_ordering BIGINT NOT NULL, user_id TEXT NOT NULL, rule_id TEXT NOT NULL, op TEXT NOT NULL, priority_class SMALLINT, priority INTEGER, conditions TEXT, actions TEXT ); +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); +CREATE TABLE ex_outlier_stream( event_stream_ordering BIGINT PRIMARY KEY NOT NULL, event_id TEXT NOT NULL, state_group BIGINT NOT NULL ); +CREATE TABLE threepid_guest_access_tokens( medium TEXT, address TEXT, guest_access_token TEXT, first_inviter TEXT ); +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); +CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, event_id TEXT NOT NULL, room_id TEXT NOT NULL, locally_rejected TEXT, replaced_by TEXT ); +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); +CREATE TABLE open_id_tokens ( token TEXT NOT NULL PRIMARY KEY, ts_valid_until_ms bigint NOT NULL, user_id TEXT NOT NULL, UNIQUE (token) ); +CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); +CREATE TABLE pusher_throttle( pusher BIGINT NOT NULL, room_id TEXT NOT NULL, last_sent_ts BIGINT, throttle_ms BIGINT, PRIMARY KEY (pusher, room_id) ); +CREATE TABLE event_reports( id BIGINT NOT NULL PRIMARY KEY, received_ts BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, reason TEXT, content TEXT ); +CREATE TABLE devices ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, display_name TEXT, CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) ); +CREATE TABLE appservice_stream_position( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT, CHECK (Lock='X') ); +CREATE TABLE device_inbox ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, stream_id BIGINT NOT NULL, message_json TEXT NOT NULL ); +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX received_transactions_ts ON received_transactions(ts); +CREATE TABLE device_federation_outbox ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, queued_ts BIGINT NOT NULL, messages_json TEXT NOT NULL ); +CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox(destination, stream_id); +CREATE TABLE 
device_federation_inbox ( origin TEXT NOT NULL, message_id TEXT NOT NULL, received_ts BIGINT NOT NULL ); +CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox(origin, message_id); +CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); +CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); +CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); +CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); +CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); +CREATE TABLE IF NOT EXISTS "event_auth"( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); +CREATE TABLE appservice_room_list( appservice_id TEXT NOT NULL, network_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( appservice_id, network_id, room_id ); +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); +CREATE TABLE federation_stream_position( type TEXT NOT NULL, stream_id INTEGER NOT NULL ); +CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, content TEXT NOT NULL ); +CREATE TABLE device_lists_remote_extremeties ( user_id TEXT NOT NULL, stream_id TEXT NOT NULL ); +CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL ); +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); +CREATE TABLE device_lists_outbound_pokes ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, sent BOOLEAN NOT NULL, ts BIGINT NOT NULL ); +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); +CREATE TABLE event_push_summary ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, notif_count BIGINT NOT NULL, stream_ordering BIGINT NOT NULL ); +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); +CREATE TABLE event_push_summary_stream_ordering ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT NOT NULL, CHECK (Lock='X') ); +CREATE TABLE IF NOT EXISTS "pushers" ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, access_token BIGINT DEFAULT NULL, profile_tag TEXT NOT NULL, kind TEXT NOT NULL, app_id TEXT NOT NULL, app_display_name TEXT NOT NULL, device_display_name TEXT NOT NULL, pushkey TEXT NOT NULL, ts BIGINT NOT NULL, lang TEXT, data TEXT, last_stream_ordering INTEGER, last_success BIGINT, failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ); +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); +CREATE TABLE ratelimit_override ( user_id TEXT NOT NULL, messages_per_second BIGINT, burst_count BIGINT ); +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); +CREATE TABLE current_state_delta_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, 
type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT, prev_event_id TEXT ); +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); +CREATE TABLE device_lists_outbound_last_success ( destination TEXT NOT NULL, user_id TEXT NOT NULL, stream_id BIGINT NOT NULL ); +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( destination, user_id, stream_id ); +CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); +CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) +/* user_directory_search(user_id,value) */; +CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value'); +CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); +CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); +CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); +CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB); +CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); +CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); +CREATE TABLE group_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, is_public BOOLEAN NOT NULL ); +CREATE TABLE group_invites ( group_id TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE TABLE group_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, is_public BOOLEAN NOT NULL ); +CREATE TABLE group_summary_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, category_id TEXT NOT NULL, room_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id, room_id, room_order), CHECK (room_order > 0) ); +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); +CREATE TABLE group_summary_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, cat_order BIGINT NOT NULL, UNIQUE (group_id, category_id, cat_order), CHECK (cat_order > 0) ); +CREATE TABLE group_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id) ); +CREATE TABLE group_summary_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, role_id TEXT NOT NULL, user_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL ); +CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); +CREATE TABLE group_summary_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, role_order BIGINT NOT NULL, UNIQUE (group_id, role_id, role_order), CHECK (role_order > 0) ); +CREATE TABLE group_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, role_id) ); +CREATE TABLE group_attestations_renewals ( 
group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL ); +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); +CREATE TABLE group_attestations_remote ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL, attestation_json TEXT NOT NULL ); +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); +CREATE TABLE local_group_membership ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, membership TEXT NOT NULL, is_publicised BOOLEAN NOT NULL, content TEXT NOT NULL ); +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); +CREATE TABLE local_group_updates ( stream_id BIGINT NOT NULL, group_id TEXT NOT NULL, user_id TEXT NOT NULL, type TEXT NOT NULL, content TEXT NOT NULL ); +CREATE TABLE remote_profile_cache ( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, last_check BIGINT NOT NULL ); +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); +CREATE TABLE IF NOT EXISTS "deleted_pushers" ( stream_id BIGINT NOT NULL, app_id TEXT NOT NULL, pushkey TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); +CREATE TABLE IF NOT EXISTS "groups" ( group_id TEXT NOT NULL, name TEXT, avatar_url TEXT, short_description TEXT, long_description TEXT, is_public BOOL NOT NULL , join_policy TEXT NOT NULL DEFAULT 'invite'); +CREATE UNIQUE INDEX groups_idx ON groups(group_id); +CREATE TABLE IF NOT EXISTS "user_directory" ( user_id TEXT NOT NULL, room_id TEXT, display_name TEXT, avatar_url TEXT ); +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); +CREATE TABLE event_push_actions_staging ( event_id TEXT NOT NULL, user_id TEXT NOT NULL, actions TEXT NOT NULL, notif SMALLINT NOT NULL, highlight SMALLINT NOT NULL ); +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); +CREATE TABLE users_pending_deactivation ( user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); +CREATE INDEX group_users_u_idx ON group_users(user_id); +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); +CREATE TABLE erased_users ( user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); +CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, timestamp BIGINT NOT NULL ); +CREATE UNIQUE INDEX 
monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); +CREATE TABLE IF NOT EXISTS "e2e_room_keys_versions" ( user_id TEXT NOT NULL, version BIGINT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); +CREATE TABLE IF NOT EXISTS "e2e_room_keys" ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, version BIGINT NOT NULL, first_message_index INT, forwarded_count INT, is_verified BOOLEAN, session_data TEXT NOT NULL ); +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); +CREATE TABLE users_who_share_private_rooms ( user_id TEXT NOT NULL, other_user_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); +CREATE TABLE user_threepid_id_server ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, id_server TEXT NOT NULL ); +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( user_id, medium, address, id_server ); +CREATE TABLE users_in_public_rooms ( user_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); +CREATE TABLE account_validity ( user_id TEXT PRIMARY KEY, expiration_ts_ms BIGINT NOT NULL, email_sent BOOLEAN NOT NULL, renewal_token TEXT ); +CREATE TABLE event_relations ( event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL, relation_type TEXT NOT NULL, aggregation_key TEXT ); +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); +CREATE TABLE stats_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); +CREATE TABLE user_stats ( user_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, public_rooms INT NOT NULL, private_rooms INT NOT NULL ); +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); +CREATE TABLE room_stats ( room_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, current_state_events INT NOT NULL, joined_members INT NOT NULL, invited_members INT NOT NULL, left_members INT NOT NULL, banned_members INT NOT NULL, state_events INT NOT NULL ); +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); +CREATE TABLE room_state ( room_id TEXT NOT NULL, join_rules TEXT, history_visibility TEXT, encryption TEXT, name TEXT, topic TEXT, avatar TEXT, canonical_alias TEXT ); +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); +CREATE TABLE room_stats_earliest_token ( room_id TEXT NOT NULL, token BIGINT NOT NULL ); +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); +CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); +CREATE INDEX user_ips_device_id ON user_ips (user_id, device_id, last_seen); +CREATE INDEX event_contains_url_index ON events (room_id, topological_ordering, stream_ordering); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering); +CREATE INDEX 
event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering); +CREATE INDEX current_state_events_member_index ON current_state_events (state_key); +CREATE INDEX device_inbox_stream_id_user_id ON device_inbox (stream_id, user_id); +CREATE INDEX device_lists_stream_user_id ON device_lists_stream (user_id, device_id); +CREATE INDEX local_media_repository_url_idx ON local_media_repository (created_ts); +CREATE INDEX user_ips_last_seen ON user_ips (user_id, last_seen); +CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); +CREATE INDEX users_creation_ts ON users (creation_ts); +CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); +CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); +CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); +CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql new file mode 100644 index 0000000000..91d21b2921 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql @@ -0,0 +1,8 @@ + +INSERT INTO appservice_stream_position (stream_ordering) SELECT COALESCE(MAX(stream_ordering), 0) FROM events; +INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); +INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; +INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); +INSERT INTO stats_stream_pos (stream_id) VALUES (0); +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); +-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/full_schemas/README.md b/synapse/storage/databases/main/schema/full_schemas/README.md new file mode 100644 index 0000000000..c00f287190 --- /dev/null +++ b/synapse/storage/databases/main/schema/full_schemas/README.md @@ -0,0 +1,21 @@ +# Synapse Database Schemas + +These schemas are used as a basis to create brand new Synapse databases, on both +SQLite3 and Postgres. + +## Building full schema dumps + +If you want to recreate these schemas, they need to be made from a database that +has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce new +`full.sql.postgres` and `full.sql.sqlite` files. + +Ensure Postgres is installed and your user is able to run commands +such as `createdb`, then call: + + ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ + +There are currently two folders with full-schema snapshots. `16` is a snapshot +from 2015, for historical reference. The other, `54`, contains the most recent full +schema snapshot. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py new file mode 100644 index 0000000000..2162d0712d --- /dev/null +++ b/synapse/storage/databases/main/search.py @@ -0,0 +1,710 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from collections import namedtuple + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +logger = logging.getLogger(__name__) + +SearchEntry = namedtuple( + "SearchEntry", + ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], +) + + +class SearchWorkerStore(SQLBaseStore): + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchBackgroundUpdateStore(SearchWorkerStore): + + EVENT_SEARCH_UPDATE_NAME = "event_search" + EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" + EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" + EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + if not hs.config.enable_search: + return + + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search + ) + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order + ) + + # we used to have a background update to turn the GIN index into a + # GIST one; we no longer do that (obviously) because we actually want + # a GIN index. However, it's possible that some people might still have + # the background update queued, so we register a handler to clear the + # background update. 
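For illustration, a stripped-down, runnable version of the SQLite branch of `store_search_entries_txn` above, against an in-memory database. The plain `event_search` table is a simplification for the sketch; the queries elsewhere in this file (`value MATCH ?`, `matchinfo(event_search)`) imply the real SQLite table is an FTS virtual table.

    import sqlite3
    from collections import namedtuple

    SearchEntry = namedtuple(
        "SearchEntry",
        ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
    )

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_search (event_id TEXT, room_id TEXT, key TEXT, value TEXT)"
    )

    entries = [
        SearchEntry("content.body", "hello world", "$ev1", "!room:hs", 1, 0),
        SearchEntry("content.topic", "a test topic", "$ev2", "!room:hs", 2, 0),
    ]

    # Same shape as the SQLite branch above: a generator of parameter tuples
    # handed to executemany.
    args = ((e.event_id, e.room_id, e.key, e.value) for e in entries)
    conn.executemany(
        "INSERT INTO event_search (event_id, room_id, key, value) VALUES (?,?,?,?)", args
    )
    print(conn.execute("SELECT count(*) FROM event_search").fetchone())  # (2,)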
+ self.db_pool.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) + + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search + ) + + @defer.inlineCallbacks + def _background_reindex_search(self, progress, batch_size): + # we work through the events table from highest stream id to lowest + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + TYPES = ["m.room.name", "m.room.message", "m.room.topic"] + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, room_id, type, json, " + " origin_server_ts FROM events" + " JOIN event_json USING (room_id, event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " AND (%s)" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + # we could stream straight from the results into + # store_search_entries_txn with a generator function, but that + # would mean having two cursors open on the database at once. + # Instead we just build a list of results. + rows = self.db_pool.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + event_search_rows = [] + for row in rows: + try: + event_id = row["event_id"] + room_id = row["room_id"] + etype = row["type"] + stream_ordering = row["stream_ordering"] + origin_server_ts = row["origin_server_ts"] + try: + event_json = db_to_json(row["json"]) + content = event_json["content"] + except Exception: + continue + + if etype == "m.room.message": + key = "content.body" + value = content["body"] + elif etype == "m.room.topic": + key = "content.topic" + value = content["topic"] + elif etype == "m.room.name": + key = "content.name" + value = content["name"] + else: + raise Exception("unexpected event type %s" % etype) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + if not isinstance(value, str): + # If the event body, name or topic isn't a string + # then skip over it + continue + + event_search_rows.append( + SearchEntry( + key=key, + value=value, + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + origin_server_ts=origin_server_ts, + ) + ) + + self.store_search_entries_txn(txn, event_search_rows) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(event_search_rows), + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.EVENT_SEARCH_UPDATE_NAME, progress + ) + + return len(event_search_rows) + + result = yield self.db_pool.runInteraction( + self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn + ) + + if not result: + yield self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_UPDATE_NAME + ) + + return result + + @defer.inlineCallbacks + def _background_reindex_gin_search(self, progress, batch_size): + """This handles old synapses which used GIST indexes, if any; + converting them back to be GIN as per the actual schema. + """ + + def create_index(conn): + conn.rollback() + + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. 
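To make the batching in `_background_reindex_search` above easier to follow: each run handles one window of at most `batch_size` rows, moving `max_stream_id_exclusive` down towards `target_min_stream_id_inclusive` until no rows come back. A rough, pure-Python sketch of that bookkeeping (the numbers and the `fetch_batch` helper are invented for illustration):

    def fetch_batch(low, high, batch_size):
        # Hypothetical stand-in for the SQL above: stream orderings in [low, high),
        # newest first, at most batch_size of them.
        return list(range(high - 1, max(low, high - batch_size) - 1, -1))

    progress = {
        "target_min_stream_id_inclusive": 0,
        "max_stream_id_exclusive": 10,
        "rows_inserted": 0,
    }

    while True:
        rows = fetch_batch(
            progress["target_min_stream_id_inclusive"],
            progress["max_stream_id_exclusive"],
            batch_size=4,
        )
        if not rows:
            break  # corresponds to ending the background update
        progress["max_stream_id_exclusive"] = rows[-1]  # lowest stream ordering seen
        progress["rows_inserted"] += len(rows)

    print(progress["rows_inserted"])  # 10 -- every stream ordering visited exactly once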
+ conn.set_session(autocommit=True) + + try: + c = conn.cursor() + + # if we skipped the conversion to GIST, we may already/still + # have an event_search_fts_idx; unfortunately postgres 9.4 + # doesn't support CREATE INDEX IF EXISTS so we just catch the + # exception and ignore it. + import psycopg2 + + try: + c.execute( + "CREATE INDEX CONCURRENTLY event_search_fts_idx" + " ON event_search USING GIN (vector)" + ) + except psycopg2.ProgrammingError as e: + logger.warning( + "Ignoring error %r when trying to switch from GIST to GIN", e + ) + + # we should now be able to delete the GIST index. + c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist") + finally: + conn.set_session(autocommit=False) + + if isinstance(self.database_engine, PostgresEngine): + yield self.db_pool.runWithConnection(create_index) + + yield self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) + return 1 + + @defer.inlineCallbacks + def _background_reindex_search_order(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + have_added_index = progress["have_added_indexes"] + + if not have_added_index: + + def create_index(conn): + conn.rollback() + conn.set_session(autocommit=True) + c = conn.cursor() + + # We create with NULLS FIRST so that when we search *backwards* + # we get the ones with non null origin_server_ts *first* + c.execute( + "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" + "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + ) + c.execute( + "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" + "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + ) + conn.set_session(autocommit=False) + + yield self.db_pool.runWithConnection(create_index) + + pg = dict(progress) + pg["have_added_indexes"] = True + + yield self.db_pool.runInteraction( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + self.db_pool.updates._background_update_progress_txn, + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + pg, + ) + + def reindex_search_txn(txn): + sql = ( + "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," + " origin_server_ts = e.origin_server_ts" + " FROM events AS e" + " WHERE e.event_id = es.event_id" + " AND ? <= e.stream_ordering AND e.stream_ordering < ?" + " RETURNING es.stream_ordering" + ) + + min_stream_id = max_stream_id - batch_size + txn.execute(sql, (min_stream_id, max_stream_id)) + rows = txn.fetchall() + + if min_stream_id < target_min_stream_id: + # We've recached the end. 
+ return len(rows), False + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + "have_added_indexes": True, + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress + ) + + return len(rows), True + + num_rows, finished = yield self.db_pool.runInteraction( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn + ) + + if not finished: + yield self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME + ) + + return num_rows + + +class SearchStore(SearchBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) + + @defer.inlineCallbacks + def search_msgs(self, room_ids, search_term, keys): + """Performs a full text search over events with given keys. + + Args: + room_ids (list): List of room ids to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + + Returns: + list of dicts + """ + clauses = [] + + search_query = _parse_query(self.database_engine, search_term) + + args = [] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append("(%s)" % (" OR ".join(local_clauses),)) + + count_args = args + count_clauses = clauses + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + " room_id, event_id" + " FROM event_search" + " WHERE vector @@ to_tsquery('english', ?)" + ) + args = [search_query, search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE vector @@ to_tsquery('english', ?)" + ) + count_args = [search_query] + count_args + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" + " FROM event_search" + " WHERE value MATCH ?" + ) + args = [search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE value MATCH ?" + ) + count_args = [search_term] + count_args + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + for clause in clauses: + sql += " AND " + clause + + for clause in count_clauses: + count_sql += " AND " + clause + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. 
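The clause assembly in `search_msgs` above is easier to see with concrete values. A small sketch of what the key filter ends up looking like for a typical `keys` argument (this mirrors the loop above; `make_in_list_sql_clause` itself is not reproduced here):

    keys = ["content.body", "content.name", "content.topic"]

    args = []
    local_clauses = []
    for key in keys:
        local_clauses.append("key = ?")
        args.append(key)

    clause = "(%s)" % (" OR ".join(local_clauses),)
    print(clause)  # (key = ? OR key = ? OR key = ?)
    print(args)    # ['content.body', 'content.name', 'content.topic']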
+ sql += " ORDER BY rank DESC LIMIT 500" + + results = yield self.db_pool.execute( + "search_msgs", self.db_pool.cursor_to_dict, sql, *args + ) + + results = list(filter(lambda row: row["room_id"] in room_ids, results)) + + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) + + event_map = {ev.event_id: ev for ev in events} + + highlights = None + if isinstance(self.database_engine, PostgresEngine): + highlights = yield self._find_highlights_in_postgres(search_query, events) + + count_sql += " GROUP BY room_id" + + count_results = yield self.db_pool.execute( + "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args + ) + + count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + + return { + "results": [ + {"event": event_map[r["event_id"]], "rank": r["rank"]} + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } + + @defer.inlineCallbacks + def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + """Performs a full text search over events with given keys. + + Args: + room_id (list): The room_ids to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + pagination_token (str): A pagination token previously returned + + Returns: + list of dicts + """ + clauses = [] + + search_query = _parse_query(self.database_engine, search_term) + + args = [] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append("(%s)" % (" OR ".join(local_clauses),)) + + # take copies of the current args and clauses lists, before adding + # pagination clauses to main query. + count_args = list(args) + count_clauses = list(clauses) + + if pagination_token: + try: + origin_server_ts, stream = pagination_token.split(",") + origin_server_ts = int(origin_server_ts) + stream = int(stream) + except Exception: + raise SynapseError(400, "Invalid pagination token") + + clauses.append( + "(origin_server_ts < ?" + " OR (origin_server_ts = ? AND stream_ordering < ?))" + ) + args.extend([origin_server_ts, origin_server_ts, stream]) + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + " origin_server_ts, stream_ordering, room_id, event_id" + " FROM event_search" + " WHERE vector @@ to_tsquery('english', ?) AND " + ) + args = [search_query, search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE vector @@ to_tsquery('english', ?) AND " + ) + count_args = [search_query] + count_args + elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. + # https://sqlite.org/optoverview.html#crossjoin + # + # We want to use the full text search index on event_search to + # extract all possible matches first, then lookup those matches + # in the events table to get the topological ordering. 
We need + # to use the indexes in this order because sqlite refuses to + # MATCH unless it uses the full text search index + sql = ( + "SELECT rank(matchinfo) as rank, room_id, event_id," + " origin_server_ts, stream_ordering" + " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" + " FROM event_search" + " WHERE value MATCH ?" + " )" + " CROSS JOIN events USING (event_id)" + " WHERE " + ) + args = [search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE value MATCH ? AND " + ) + count_args = [search_term] + count_args + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + sql += " AND ".join(clauses) + count_sql += " AND ".join(count_clauses) + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. + if isinstance(self.database_engine, PostgresEngine): + sql += ( + " ORDER BY origin_server_ts DESC NULLS LAST," + " stream_ordering DESC NULLS LAST LIMIT ?" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" + else: + raise Exception("Unrecognized database engine") + + args.append(limit) + + results = yield self.db_pool.execute( + "search_rooms", self.db_pool.cursor_to_dict, sql, *args + ) + + results = list(filter(lambda row: row["room_id"] in room_ids, results)) + + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) + + event_map = {ev.event_id: ev for ev in events} + + highlights = None + if isinstance(self.database_engine, PostgresEngine): + highlights = yield self._find_highlights_in_postgres(search_query, events) + + count_sql += " GROUP BY room_id" + + count_results = yield self.db_pool.execute( + "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args + ) + + count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + + return { + "results": [ + { + "event": event_map[r["event_id"]], + "rank": r["rank"], + "pagination_token": "%s,%s" + % (r["origin_server_ts"], r["stream_ordering"]), + } + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } + + def _find_highlights_in_postgres(self, search_query, events): + """Given a list of events and a search term, return a list of words + that match from the content of the event. + + This is used to give a list of words that clients can match against to + highlight the matching parts. + + Args: + search_query (str) + events (list): A list of events + + Returns: + deferred : A set of strings. + """ + + def f(txn): + highlight_words = set() + for event in events: + # As a hack we simply join values of all possible keys. This is + # fine since we're only using them to find possible highlights. + values = [] + for key in ("body", "name", "topic"): + v = event.content.get(key, None) + if v: + values.append(v) + + if not values: + continue + + value = " ".join(values) + + # We need to find some values for StartSel and StopSel that + # aren't in the value so that we can pick results out. 
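To illustrate the delimiter trick described just above: the markers are grown until they no longer occur in the text, `ts_headline` is asked to wrap matches in them, and the matches are then pulled back out with a regex. A self-contained sketch of that extraction step (the `headline` string stands in for what Postgres's `ts_headline` would return):

    import re

    value = "a <b> c"          # note it already contains "<" and ">"
    start_sel, stop_sel = "<", ">"
    while start_sel in value:
        start_sel += "<"
    while stop_sel in value:
        stop_sel += ">"
    print(start_sel, stop_sel)  # << >>

    # Pretend ts_headline() highlighted the word "needle" for us.
    headline = "some text with a %sneedle%s in it" % (start_sel, stop_sel)
    matcher_regex = "%s(.*?)%s" % (re.escape(start_sel), re.escape(stop_sel))
    print(re.findall(matcher_regex, headline))  # ['needle']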
+ start_sel = "<" + stop_sel = ">" + + while start_sel in value: + start_sel += "<" + while stop_sel in value: + stop_sel += ">" + + query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) + ) + txn.execute(query, (value, search_query)) + (headline,) = txn.fetchall()[0] + + # Now we need to pick the possible highlights out of the haedline + # result. + matcher_regex = "%s(.*?)%s" % ( + re.escape(start_sel), + re.escape(stop_sel), + ) + + res = re.findall(matcher_regex, headline) + highlight_words.update([r.lower() for r in res]) + + return highlight_words + + return self.db_pool.runInteraction("_find_highlights", f) + + +def _to_postgres_options(options_dict): + return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) + + +def _parse_query(database_engine, search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + if isinstance(database_engine, PostgresEngine): + return " & ".join(result + ":*" for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join(result + "*" for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py new file mode 100644 index 0000000000..dae8e8bd29 --- /dev/null +++ b/synapse/storage/databases/main/signatures.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
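A concrete example of what the `_parse_query` helper above produces for the two engines; the prefix markers (`:*` for a Postgres tsquery, `*` for SQLite FTS) are what give prefix matching. The helper is restated here with a plain boolean flag instead of the engine object, purely for illustration:

    import re

    def parse_query(search_term, postgres=True):
        # Same tokenisation as _parse_query above.
        results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
        if postgres:
            return " & ".join(result + ":*" for result in results)
        return " & ".join(result + "*" for result in results)

    print(parse_query("hello wor"))                  # hello:* & wor:*
    print(parse_query("hello wor", postgres=False))  # hello* & wor*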
+ +from unpaddedbase64 import encode_base64 + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList + + +class SignatureWorkerStore(SQLBaseStore): + @cached() + def get_event_reference_hash(self, event_id): + # This is a dummy function to allow get_event_reference_hashes + # to use its cache + raise NotImplementedError() + + @cachedList( + cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 + ) + def get_event_reference_hashes(self, event_ids): + def f(txn): + return { + event_id: self._get_event_reference_hashes_txn(txn, event_id) + for event_id in event_ids + } + + return self.db_pool.runInteraction("get_event_reference_hashes", f) + + @defer.inlineCallbacks + def add_event_hashes(self, event_ids): + hashes = yield self.get_event_reference_hashes(event_ids) + hashes = { + e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} + for e_id, h in hashes.items() + } + + return list(hashes.items()) + + def _get_event_reference_hashes_txn(self, txn, event_id): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict[unicode, bytes] of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_reference_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id,)) + return {k: v for k, v in txn} + + +class SignatureStore(SignatureWorkerStore): + """Persistence for event signatures and hashes""" diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py new file mode 100644 index 0000000000..96e0378e50 --- /dev/null +++ b/synapse/storage/databases/main/state.py @@ -0,0 +1,509 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
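The shape of the data `add_event_hashes` above produces is easier to see with a toy value: raw `sha256` bytes from `event_reference_hashes` come out as unpadded base64 strings, keyed by event ID. A small sketch (the event ID and bytes are invented; `unpaddedbase64` is the same helper library imported above):

    import hashlib
    from unpaddedbase64 import encode_base64

    # Pretend this is what get_event_reference_hashes returned for one event.
    hashes = {"$event1:example.com": {"sha256": hashlib.sha256(b"some event").digest()}}

    encoded = {
        e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
        for e_id, h in hashes.items()
    }
    print(encoded["$event1:example.com"]["sha256"])
    # -> a 43-character base64 string with no trailing "=" padding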
+import collections.abc +import logging +from collections import namedtuple +from typing import Iterable, Optional, Set + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.state import StateFilter +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached, cachedList + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +# this inherits from EventsWorkerStore because it calls self.get_events +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. + """ + + def __init__(self, database: DatabasePool, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) + + async def get_room_version(self, room_id: str) -> RoomVersion: + """Get the room_version of a given room + + Raises: + NotFoundError: if the room is unknown + + UnsupportedRoomVersionError: if the room uses an unknown room version. + Typically this happens if support for the room's version has been + removed from Synapse. + """ + room_version_id = await self.get_room_version_id(room_id) + v = KNOWN_ROOM_VERSIONS.get(room_version_id) + + if not v: + raise UnsupportedRoomVersionError( + "Room %s uses a room version %s which is no longer supported" + % (room_id, room_version_id) + ) + + return v + + @cached(max_entries=10000) + async def get_room_version_id(self, room_id: str) -> str: + """Get the room_version of a given room + + Raises: + NotFoundError: if the room is unknown + """ + + # First we try looking up room version from the database, but for old + # rooms we might not have added the room version to it yet so we fall + # back to previous behaviour and look in current state events. + + # We really should have an entry in the rooms table for every room we + # care about, but let's be a bit paranoid (at least while the background + # update is happening) to avoid breaking existing rooms. + version = await self.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="room_version", + desc="get_room_version", + allow_none=True, + ) + + if version is not None: + return version + + # Retrieve the room's create event + create_event = await self.get_create_event_for_room(room_id) + return create_event.content.get("room_version", "1") + + async def get_room_predecessor(self, room_id: str) -> Optional[dict]: + """Get the predecessor of an upgraded room if it exists. + Otherwise return None. + + Args: + room_id: The room ID. + + Returns: + A dictionary containing the structure of the predecessor + field from the room's create event. 
The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room + + None if a predecessor key is not found, or is not a dictionary. + + Raises: + NotFoundError if the given room is unknown + """ + # Retrieve the room's create event + create_event = await self.get_create_event_for_room(room_id) + + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, collections.abc.Mapping): + return None + + return predecessor + + async def get_create_event_for_room(self, room_id: str) -> EventBase: + """Get the create state event for a room. + + Args: + room_id: The room ID. + + Returns: + The room creation event. + + Raises: + NotFoundError if the room is unknown + """ + state_ids = await self.get_current_state_ids(room_id) + create_id = state_ids.get((EventTypes.Create, "")) + + # If we can't find the create event, assume we've hit a dead end + if not create_id: + raise NotFoundError("Unknown room %s" % (room_id,)) + + # Retrieve the room's create event and return + create_event = await self.get_event(create_id) + return create_event + + @cached(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) + + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} + + return self.db_pool.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn + ) + + # FIXME: how should this be cached? + def get_filtered_current_state_ids( + self, room_id: str, state_filter: StateFilter = StateFilter.all() + ): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + + Args: + room_id + state_filter: The state filter used to fetch state + from the database. + + Returns: + defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if not where_clause: + # We delegate to the cached version + return self.get_current_state_ids(room_id) + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? 
+ """ + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + + return results + + return self.db_pool.runInteraction( + "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn + ) + + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: + """Get canonical alias for room, if any + + Args: + room_id: The room ID + + Returns: + The canonical alias, if any + """ + + state = await self.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return + + event = await self.get_event(event_id, allow_none=True) + if not event: + return + + return event.content.get("canonical_alias") + + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): + return self.db_pool.simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList( + cached_method_name="_get_state_group_for_event", + list_name="event_ids", + num_args=1, + inlineCallbacks=True, + ) + def _get_state_group_for_events(self, event_ids): + """Returns mapping event_id -> state_group + """ + rows = yield self.db_pool.simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group"), + desc="_get_state_group_for_events", + ) + + return {row["event_id"]: row["state_group"] for row in rows} + + async def get_referenced_state_groups( + self, state_groups: Iterable[int] + ) -> Set[int]: + """Check if the state groups are referenced by events. + + Args: + state_groups + + Returns: + The subset of state groups that are referenced. 
+ """ + + rows = await self.db_pool.simple_select_many_batch( + table="event_to_state_groups", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("DISTINCT state_group",), + desc="get_referenced_state_groups", + ) + + return {row["state_group"] for row in rows} + + +class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): + + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" + DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.server_name = hs.hostname + + self.db_pool.updates.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + self.db_pool.updates.register_background_index_update( + self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, + index_name="event_to_state_groups_sg_index", + table="event_to_state_groups", + columns=["state_group"], + ) + self.db_pool.updates.register_background_update_handler( + self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, + ) + + async def _background_remove_left_rooms(self, progress, batch_size): + """Background update to delete rows from `current_state_events` and + `event_forward_extremities` tables of rooms that the server is no + longer joined to. + """ + + last_room_id = progress.get("last_room_id", "") + + def _background_remove_left_rooms_txn(txn): + # get a batch of room ids to consider + sql = """ + SELECT DISTINCT room_id FROM current_state_events + WHERE room_id > ? ORDER BY room_id LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + room_ids = [row[0] for row in txn] + if not room_ids: + return True, set() + + ########################################################################### + # + # exclude rooms where we have active members + + sql = """ + SELECT room_id + FROM local_current_membership + WHERE + room_id > ? AND room_id <= ? + AND membership = 'join' + GROUP BY room_id + """ + + txn.execute(sql, (last_room_id, room_ids[-1])) + joined_room_ids = {row[0] for row in txn} + to_delete = set(room_ids) - joined_room_ids + + ########################################################################### + # + # exclude rooms which we are in the process of constructing; these otherwise + # qualify as "rooms with no local users", and would have their + # forward extremities cleaned up. + + # the following query will return a list of rooms which have forward + # extremities that are *not* also the create event in the room - ie + # those that are not being created currently. + + sql = """ + SELECT DISTINCT efe.room_id + FROM event_forward_extremities efe + LEFT JOIN current_state_events cse ON + cse.event_id = efe.event_id + AND cse.type = 'm.room.create' + AND cse.state_key = '' + WHERE + cse.event_id IS NULL + AND efe.room_id > ? AND efe.room_id <= ? + """ + + txn.execute(sql, (last_room_id, room_ids[-1])) + + # build a set of those rooms within `to_delete` that do not appear in + # the above, leaving us with the rooms in `to_delete` that *are* being + # created. + creating_rooms = to_delete.difference(row[0] for row in txn) + logger.info("skipping rooms which are being created: %s", creating_rooms) + + # now remove the rooms being created from the list of those to delete. 
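The set arithmetic in this transaction is simple but easy to misread; with toy data it looks like this (the room IDs are made up):

    room_ids = {"!a:hs", "!b:hs", "!c:hs", "!d:hs"}   # batch under consideration
    joined_room_ids = {"!a:hs"}                        # rooms with a local joined member
    rooms_with_non_create_extremities = {"!b:hs", "!c:hs"}  # result of the SQL above

    to_delete = set(room_ids) - joined_room_ids        # {'!b:hs', '!c:hs', '!d:hs'}
    creating_rooms = to_delete.difference(rooms_with_non_create_extremities)
    to_delete -= creating_rooms                        # keep rooms still being created

    print(sorted(creating_rooms))  # ['!d:hs']
    print(sorted(to_delete))       # ['!b:hs', '!c:hs']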
+ # + # (we could have just taken the intersection of `to_delete` with the result + # of the sql query, but it's useful to be able to log `creating_rooms`; and + # having done so, it's quicker to remove the (few) creating rooms from + # `to_delete` than it is to form the intersection with the (larger) list of + # not-creating-rooms) + + to_delete -= creating_rooms + + ########################################################################### + # + # now clear the state for the rooms + + logger.info("Deleting current state left rooms: %r", to_delete) + + # First we get all users that we still think were joined to the + # room. This is so that we can mark those device lists as + # potentially stale, since there may have been a period where the + # server didn't share a room with the remote user and therefore may + # have missed any device updates. + rows = self.db_pool.simple_select_many_txn( + txn, + table="current_state_events", + column="room_id", + iterable=to_delete, + keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, + retcols=("state_key",), + ) + + potentially_left_users = {row["state_key"] for row in rows} + + # Now lets actually delete the rooms from the DB. + self.db_pool.simple_delete_many_txn( + txn, + table="current_state_events", + column="room_id", + iterable=to_delete, + keyvalues={}, + ) + + self.db_pool.simple_delete_many_txn( + txn, + table="event_forward_extremities", + column="room_id", + iterable=to_delete, + keyvalues={}, + ) + + self.db_pool.updates._background_update_progress_txn( + txn, + self.DELETE_CURRENT_STATE_UPDATE_NAME, + {"last_room_id": room_ids[-1]}, + ) + + return False, potentially_left_users + + finished, potentially_left_users = await self.db_pool.runInteraction( + "_background_remove_left_rooms", _background_remove_left_rooms_txn + ) + + if finished: + await self.db_pool.updates._end_background_update( + self.DELETE_CURRENT_STATE_UPDATE_NAME + ) + + # Now go and check if we still share a room with the remote users in + # the deleted rooms. If not mark their device lists as stale. + joined_users = await self.get_users_server_still_shares_room_with( + potentially_left_users + ) + + for user_id in potentially_left_users - joined_users: + await self.mark_remote_user_device_list_as_unsubscribed(user_id) + + return batch_size + + +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. 
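A toy picture of those tables as Python mappings may make the docstring above more concrete (all IDs invented for illustration):

    # state_groups_state: state group -> the (type, state_key) -> event_id map it represents
    state_groups_state = {
        1: {("m.room.create", ""): "$create"},
        2: {("m.room.create", ""): "$create", ("m.room.topic", ""): "$topic"},
    }

    # event_to_state_groups: event -> state group describing the room state at that event.
    # Two message events sent while the topic was set share group 2: no state change,
    # so no new group is created.
    event_to_state_groups = {"$message1": 2, "$message2": 2}

    def state_at(event_id):
        return state_groups_state[event_to_state_groups[event_id]]

    print(state_at("$message1"))
    # {('m.room.create', ''): '$create', ('m.room.topic', ''): '$topic'}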
+ """ + + def __init__(self, database: DatabasePool, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py new file mode 100644 index 0000000000..0d963c98ff --- /dev/null +++ b/synapse/storage/databases/main/state_deltas.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class StateDeltasStore(SQLBaseStore): + def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id (int): point to get changes since (exclusive) + max_stream_id (int): the point that we know has been correctly persisted + - ie, an upper limit to return changes from. + + Returns: + Deferred[tuple[int, list[dict]]: A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. + """ + prev_stream_id = int(prev_stream_id) + + # check we're not going backwards + assert prev_stream_id <= max_stream_id + + if not self._curr_state_delta_stream_cache.has_any_entity_changed( + prev_stream_id + ): + # if the CSDs haven't changed between prev_stream_id and now, we + # know for certain that they haven't changed between prev_stream_id and + # max_stream_id. + return defer.succeed((max_stream_id, [])) + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? AND stream_id <= ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id, max_stream_id)) + + total = 0 + + for stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + logger.debug( + "Clipping current_state_delta_stream rows to stream_id %i", + stream_id, + ) + clipped_stream_id = stream_id + break + else: + # if there's no problem, we may as well go right up to the max_stream_id + clipped_stream_id = max_stream_id + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? 
+ ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, clipped_stream_id)) + return clipped_stream_id, self.db_pool.cursor_to_dict(txn) + + return self.db_pool.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self.db_pool.runInteraction( + "get_max_stream_id_in_current_state_deltas", + self._get_max_stream_id_in_current_state_deltas_txn, + ) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py new file mode 100644 index 0000000000..802c9019b9 --- /dev/null +++ b/synapse/storage/databases/main/stats.py @@ -0,0 +1,886 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from itertools import chain +from typing import Tuple + +from twisted.internet.defer import DeferredLock + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.state_deltas import StateDeltasStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# these fields track absolutes (e.g. total number of rooms on the server) +# You can think of these as Prometheus Gauges. +# You can draw these stats on a line graph. +# Example: number of users in a room +ABSOLUTE_STATS_FIELDS = { + "room": ( + "current_state_events", + "joined_members", + "invited_members", + "left_members", + "banned_members", + "local_users_in_room", + ), + "user": ("joined_rooms",), +} + +# these fields are per-timeslice and so should be reset to 0 upon a new slice +# You can draw these stats on a histogram. 
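Returning to the clipping step in `get_current_state_deltas_txn` a little earlier: it walks the per-stream-id counts in ascending order and stops at the first stream id that pushes the running total past 100, so a single stream id is never split across two calls. A pure-Python sketch of that loop (the counts are invented):

    max_stream_id = 50
    # (stream_id, count) rows as returned by the GROUP BY query, in ascending order.
    rows = [(10, 40), (11, 70), (12, 5)]

    total = 0
    clipped_stream_id = max_stream_id
    for stream_id, count in rows:
        total += count
        if total > 100:
            clipped_stream_id = stream_id
            break

    print(clipped_stream_id)  # 11 -- deltas are fetched up to and including stream id 11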
+# Example: number of events sent locally during a time slice +PER_SLICE_FIELDS = { + "room": ("total_events", "total_event_bytes"), + "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), +} + +TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} + +# these are the tables (& ID columns) which contain our actual subjects +TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} + + +class StatsStore(StateDeltasStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) + + self.server_name = hs.hostname + self.clock = self.hs.get_clock() + self.stats_enabled = hs.config.stats_enabled + self.stats_bucket_size = hs.config.stats_bucket_size + + self.stats_delta_processing_lock = DeferredLock() + + self.db_pool.updates.register_background_update_handler( + "populate_stats_process_rooms", self._populate_stats_process_rooms + ) + self.db_pool.updates.register_background_update_handler( + "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 + ) + self.db_pool.updates.register_background_update_handler( + "populate_stats_process_users", self._populate_stats_process_users + ) + # we no longer need to perform clean-up, but we will give ourselves + # the potential to reintroduce it in the future – so documentation + # will still encourage the use of this no-op handler. + self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") + self.db_pool.updates.register_noop_background_update("populate_stats_prepare") + + def quantise_stats_time(self, ts): + """ + Quantises a timestamp to be a multiple of the bucket size. + + Args: + ts (int): the timestamp to quantise, in milliseconds since the Unix + Epoch + + Returns: + int: a timestamp which + - is divisible by the bucket size; + - is no later than `ts`; and + - is the largest such timestamp. + """ + return (ts // self.stats_bucket_size) * self.stats_bucket_size + + async def _populate_stats_process_users(self, progress, batch_size): + """ + This is a background update which regenerates statistics for users. + """ + if not self.stats_enabled: + await self.db_pool.updates._end_background_update( + "populate_stats_process_users" + ) + return 1 + + last_user_id = progress.get("last_user_id", "") + + def _get_next_batch(txn): + sql = """ + SELECT DISTINCT name FROM users + WHERE name > ? + ORDER BY name ASC + LIMIT ? + """ + txn.execute(sql, (last_user_id, batch_size)) + return [r for r, in txn] + + users_to_work_on = await self.db_pool.runInteraction( + "_populate_stats_process_users", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not users_to_work_on: + await self.db_pool.updates._end_background_update( + "populate_stats_process_users" + ) + return 1 + + for user_id in users_to_work_on: + await self._calculate_and_set_initial_state_for_user(user_id) + progress["last_user_id"] = user_id + + await self.db_pool.runInteraction( + "populate_stats_process_users", + self.db_pool.updates._background_update_progress_txn, + "populate_stats_process_users", + progress, + ) + + return len(users_to_work_on) + + async def _populate_stats_process_rooms(self, progress, batch_size): + """ + This was a background update which regenerated statistics for rooms. + + It has been replaced by StatsStore._populate_stats_process_rooms_2. This background + job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure + someone upgrading from ? 
+ ORDER BY room_id ASC + LIMIT ? + """ + txn.execute(sql, (last_room_id, batch_size)) + return [r for r, in txn] + + rooms_to_work_on = await self.db_pool.runInteraction( + "populate_stats_rooms_2_get_batch", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + await self.db_pool.updates._end_background_update( + "populate_stats_process_rooms_2" + ) + return 1 + + for room_id in rooms_to_work_on: + await self._calculate_and_set_initial_state_for_room(room_id) + progress["last_room_id"] = room_id + + await self.db_pool.runInteraction( + "_populate_stats_process_rooms_2", + self.db_pool.updates._background_update_progress_txn, + "populate_stats_process_rooms_2", + progress, + ) + + return len(rooms_to_work_on) + + def get_stats_positions(self): + """ + Returns the stats processor positions. + """ + return self.db_pool.simple_select_one_onecol( + table="stats_incremental_position", + keyvalues={}, + retcol="stream_id", + desc="stats_incremental_position", + ) + + def update_room_state(self, room_id, fields): + """ + Args: + room_id (str) + fields (dict[str:Any]) + """ + + # For whatever reason some of the fields may contain null bytes, which + # postgres isn't a fan of, so we replace those fields with null. + for col in ( + "join_rules", + "history_visibility", + "encryption", + "name", + "topic", + "avatar", + "canonical_alias", + ): + field = fields.get(col) + if field and "\0" in field: + fields[col] = None + + return self.db_pool.simple_upsert( + table="room_stats_state", + keyvalues={"room_id": room_id}, + values=fields, + desc="update_room_state", + ) + + def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): + """ + Get statistics for a given subject. + + Args: + stats_type (str): The type of subject + stats_id (str): The ID of the subject (e.g. room_id or user_id) + start (int): Pagination start. Number of entries, not timestamp. + size (int): How many entries to return. + + Returns: + Deferred[list[dict]], where the dict has the keys of + ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". + """ + return self.db_pool.runInteraction( + "get_statistics_for_subject", + self._get_statistics_for_subject_txn, + stats_type, + stats_id, + start, + size, + ) + + def _get_statistics_for_subject_txn( + self, txn, stats_type, stats_id, start, size=100 + ): + """ + Transaction-bound version of L{get_statistics_for_subject}. + """ + + table, id_col = TYPE_TO_TABLE[stats_type] + selected_columns = list( + ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] + ) + + slice_list = self.db_pool.simple_select_list_paginate_txn( + txn, + table + "_historical", + "end_ts", + start, + size, + retcols=selected_columns + ["bucket_size", "end_ts"], + keyvalues={id_col: stats_id}, + order_direction="DESC", + ) + + return slice_list + + @cached() + def get_earliest_token_for_stats(self, stats_type, id): + """ + Fetch the "earliest token". This is used by the room stats delta + processor to ignore deltas that have been processed between the + start of the background task and any particular room's stats + being calculated. + + Returns: + Deferred[int] + """ + table, id_col = TYPE_TO_TABLE[stats_type] + + return self.db_pool.simple_select_one_onecol( + "%s_current" % (table,), + keyvalues={id_col: id}, + retcol="completed_delta_stream_id", + allow_none=True, + ) + + def bulk_update_stats_delta(self, ts, updates, stream_id): + """Bulk update stats tables for a given stream_id and updates the stats + incremental position. 
+ + Args: + ts (int): Current timestamp in ms + updates(dict[str, dict[str, dict[str, Counter]]]): The updates to + commit as a mapping stats_type -> stats_id -> field -> delta. + stream_id (int): Current position. + + Returns: + Deferred + """ + + def _bulk_update_stats_delta_txn(txn): + for stats_type, stats_updates in updates.items(): + for stats_id, fields in stats_updates.items(): + logger.debug( + "Updating %s stats for %s: %s", stats_type, stats_id, fields + ) + self._update_stats_delta_txn( + txn, + ts=ts, + stats_type=stats_type, + stats_id=stats_id, + fields=fields, + complete_with_stream_id=stream_id, + ) + + self.db_pool.simple_update_one_txn( + txn, + table="stats_incremental_position", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + ) + + return self.db_pool.runInteraction( + "bulk_update_stats_delta", _bulk_update_stats_delta_txn + ) + + def update_stats_delta( + self, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + """ + Updates the statistics for a subject, with a delta (difference/relative + change). + + Args: + ts (int): timestamp of the change + stats_type (str): "room" or "user" – the kind of subject + stats_id (str): the subject's ID (room ID or user ID) + fields (dict[str, int]): Deltas of stats values. + complete_with_stream_id (int, optional): + If supplied, converts an incomplete row into a complete row, + with the supplied stream_id marked as the stream_id where the + row was completed. + absolute_field_overrides (dict[str, int]): Current stats values + (i.e. not deltas) of absolute fields. + Does not work with per-slice fields. + """ + + return self.db_pool.runInteraction( + "update_stats_delta", + self._update_stats_delta_txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id=complete_with_stream_id, + absolute_field_overrides=absolute_field_overrides, + ) + + def _update_stats_delta_txn( + self, + txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + if absolute_field_overrides is None: + absolute_field_overrides = {} + + table, id_col = TYPE_TO_TABLE[stats_type] + + quantised_ts = self.quantise_stats_time(int(ts)) + end_ts = quantised_ts + self.stats_bucket_size + + # Lets be paranoid and check that all the given field names are known + abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] + slice_field_names = PER_SLICE_FIELDS[stats_type] + for field in chain(fields.keys(), absolute_field_overrides.keys()): + if field not in abs_field_names and field not in slice_field_names: + # guard against potential SQL injection dodginess + raise ValueError( + "%s is not a recognised field" + " for stats type %s" % (field, stats_type) + ) + + # Per slice fields do not get added to the _current table + + # This calculates the deltas (`field = field + ?` values) + # for absolute fields, + # * defaulting to 0 if not specified + # (required for the INSERT part of upserting to work) + # * omitting overrides specified in `absolute_field_overrides` + deltas_of_absolute_fields = { + key: fields.get(key, 0) + for key in abs_field_names + if key not in absolute_field_overrides + } + + # Keep the delta stream ID field up to date + absolute_field_overrides = absolute_field_overrides.copy() + absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id + + # first upsert the `_current` table + self._upsert_with_additive_relatives_txn( + txn=txn, + table=table + "_current", + keyvalues={id_col: stats_id}, + 
absolutes=absolute_field_overrides, + additive_relatives=deltas_of_absolute_fields, + ) + + per_slice_additive_relatives = { + key: fields.get(key, 0) for key in slice_field_names + } + self._upsert_copy_from_table_with_additive_relatives_txn( + txn=txn, + into_table=table + "_historical", + keyvalues={id_col: stats_id}, + extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, + extra_dst_keyvalues={"end_ts": end_ts}, + additive_relatives=per_slice_additive_relatives, + src_table=table + "_current", + copy_columns=abs_field_names, + ) + + def _upsert_with_additive_relatives_txn( + self, txn, table, keyvalues, absolutes, additive_relatives + ): + """Used to update values in the stats tables. + + This is basically a slightly convoluted upsert that *adds* to any + existing rows. + + Args: + txn + table (str): Table name + keyvalues (dict[str, any]): Row-identifying key values + absolutes (dict[str, any]): Absolute (set) fields + additive_relatives (dict[str, int]): Fields that will be added onto + if existing row present. + """ + if self.database_engine.can_native_upsert: + absolute_updates = [ + "%(field)s = EXCLUDED.%(field)s" % {"field": field} + for field in absolutes.keys() + ] + + relative_updates = [ + "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" + % {"table": table, "field": field} + for field in additive_relatives.keys() + ] + + insert_cols = [] + qargs = [] + + for (key, val) in chain( + keyvalues.items(), absolutes.items(), additive_relatives.items() + ): + insert_cols.append(key) + qargs.append(val) + + sql = """ + INSERT INTO %(table)s (%(insert_cols_cs)s) + VALUES (%(insert_vals_qs)s) + ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s + """ % { + "table": table, + "insert_cols_cs": ", ".join(insert_cols), + "insert_vals_qs": ", ".join( + ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) + ), + "key_columns": ", ".join(keyvalues), + "updates": ", ".join(chain(absolute_updates, relative_updates)), + } + + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, table) + retcols = list(chain(absolutes.keys(), additive_relatives.keys())) + current_row = self.db_pool.simple_select_one_txn( + txn, table, keyvalues, retcols, allow_none=True + ) + if current_row is None: + merged_dict = {**keyvalues, **absolutes, **additive_relatives} + self.db_pool.simple_insert_txn(txn, table, merged_dict) + else: + for (key, val) in additive_relatives.items(): + current_row[key] += val + current_row.update(absolutes) + self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) + + def _upsert_copy_from_table_with_additive_relatives_txn( + self, + txn, + into_table, + keyvalues, + extra_dst_keyvalues, + extra_dst_insvalues, + additive_relatives, + src_table, + copy_columns, + ): + """Updates the historic stats table with latest updates. + + This involves copying "absolute" fields from the `_current` table, and + adding relative fields to any existing values. + + Args: + txn: Transaction + into_table (str): The destination table to UPSERT the row into + keyvalues (dict[str, any]): Row-identifying key values + extra_dst_keyvalues (dict[str, any]): Additional keyvalues + for `into_table`. + extra_dst_insvalues (dict[str, any]): Additional values to insert + on new row creation for `into_table`. + additive_relatives (dict[str, any]): Fields that will be added onto + if existing row present. (Must be disjoint from copy_columns.) 
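The "adds to any existing rows" behaviour of the native-upsert path in `_upsert_with_additive_relatives_txn` above boils down to `ON CONFLICT ... DO UPDATE SET field = EXCLUDED.field + table.field`. A self-contained sqlite3 sketch of that SQL shape (the table and columns are invented; SQLite 3.24+ is assumed for the upsert syntax):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE room_counts (room_id TEXT PRIMARY KEY, joined_members INT)"
    )

    sql = (
        "INSERT INTO room_counts (room_id, joined_members) VALUES (?, ?)"
        " ON CONFLICT (room_id) DO UPDATE SET"
        " joined_members = EXCLUDED.joined_members + room_counts.joined_members"
    )

    conn.execute(sql, ("!room:hs", 3))   # no existing row: inserts 3
    conn.execute(sql, ("!room:hs", 2))   # existing row: adds the delta, 3 + 2
    print(conn.execute("SELECT joined_members FROM room_counts").fetchone())  # (5,)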
+ src_table (str): The source table to copy from + copy_columns (iterable[str]): The list of columns to copy + """ + if self.database_engine.can_native_upsert: + ins_columns = chain( + keyvalues, + copy_columns, + additive_relatives, + extra_dst_keyvalues, + extra_dst_insvalues, + ) + sel_exprs = chain( + keyvalues, + copy_columns, + ( + "?" + for _ in chain( + additive_relatives, extra_dst_keyvalues, extra_dst_insvalues + ) + ), + ) + keyvalues_where = ("%s = ?" % f for f in keyvalues) + + sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) + sets_ar = ( + "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) + for f in additive_relatives + ) + + sql = """ + INSERT INTO %(into_table)s (%(ins_columns)s) + SELECT %(sel_exprs)s + FROM %(src_table)s + WHERE %(keyvalues_where)s + ON CONFLICT (%(keyvalues)s) + DO UPDATE SET %(sets)s + """ % { + "into_table": into_table, + "ins_columns": ", ".join(ins_columns), + "sel_exprs": ", ".join(sel_exprs), + "keyvalues_where": " AND ".join(keyvalues_where), + "src_table": src_table, + "keyvalues": ", ".join( + chain(keyvalues.keys(), extra_dst_keyvalues.keys()) + ), + "sets": ", ".join(chain(sets_cc, sets_ar)), + } + + qargs = list( + chain( + additive_relatives.values(), + extra_dst_keyvalues.values(), + extra_dst_insvalues.values(), + keyvalues.values(), + ) + ) + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, into_table) + src_row = self.db_pool.simple_select_one_txn( + txn, src_table, keyvalues, copy_columns + ) + all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} + dest_current_row = self.db_pool.simple_select_one_txn( + txn, + into_table, + keyvalues=all_dest_keyvalues, + retcols=list(chain(additive_relatives.keys(), copy_columns)), + allow_none=True, + ) + + if dest_current_row is None: + merged_dict = { + **keyvalues, + **extra_dst_keyvalues, + **extra_dst_insvalues, + **src_row, + **additive_relatives, + } + self.db_pool.simple_insert_txn(txn, into_table, merged_dict) + else: + for (key, val) in additive_relatives.items(): + src_row[key] = dest_current_row[key] + val + self.db_pool.simple_update_txn( + txn, into_table, all_dest_keyvalues, src_row + ) + + def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): + """Fetches the counts of events in the given range of stream IDs. + + Args: + min_pos (int) + max_pos (int) + + Returns: + Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field + changes. + """ + + return self.db_pool.runInteraction( + "stats_incremental_total_events_and_bytes", + self.get_changes_room_total_events_and_bytes_txn, + min_pos, + max_pos, + ) + + def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + """Gets the total_events and total_event_bytes counts for rooms and + senders, in a range of stream_orderings (including backfilled events). + + Args: + txn + low_pos (int): Low stream ordering + high_pos (int): High stream ordering + + Returns: + tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The + room and user deltas for total_events/total_event_bytes in the + format of `stats_id` -> fields + """ + + if low_pos >= high_pos: + # nothing to do here. + return {}, {} + + if isinstance(self.database_engine, PostgresEngine): + new_bytes_expression = "OCTET_LENGTH(json)" + else: + new_bytes_expression = "LENGTH(CAST(json AS BLOB))" + + sql = """ + SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? 
<= stream_ordering AND stream_ordering <= ?) + GROUP BY events.room_id + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + room_deltas = { + room_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for room_id, new_events, new_bytes in txn + } + + sql = """ + SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? <= stream_ordering AND stream_ordering <= ?) + GROUP BY events.sender + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + user_deltas = { + user_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for user_id, new_events, new_bytes in txn + if self.hs.is_mine_id(user_id) + } + + return room_deltas, user_deltas + + async def _calculate_and_set_initial_state_for_room( + self, room_id: str + ) -> Tuple[dict, dict, int]: + """Calculate and insert an entry into room_stats_current. + + Args: + room_id: The room ID under calculation. + + Returns: + A tuple of room state, membership counts and stream position. + """ + + def _fetch_current_state_stats(txn): + pos = self.get_room_max_stream_ordering() + + rows = self.db_pool.simple_select_many_txn( + txn, + table="current_state_events", + column="type", + iterable=[ + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, + EventTypes.RoomEncryption, + EventTypes.Name, + EventTypes.Topic, + EventTypes.RoomAvatar, + EventTypes.CanonicalAlias, + ], + keyvalues={"room_id": room_id, "state_key": ""}, + retcols=["event_id"], + ) + + event_ids = [row["event_id"] for row in rows] + + txn.execute( + """ + SELECT membership, count(*) FROM current_state_events + WHERE room_id = ? AND type = 'm.room.member' + GROUP BY membership + """, + (room_id,), + ) + membership_counts = {membership: cnt for membership, cnt in txn} + + txn.execute( + """ + SELECT COALESCE(count(*), 0) FROM current_state_events + WHERE room_id = ? 
+ """, + (room_id,), + ) + + (current_state_events_count,) = txn.fetchone() + + users_in_room = self.get_users_in_room_txn(txn, room_id) + + return ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) + + ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) = await self.db_pool.runInteraction( + "get_initial_state_for_room", _fetch_current_state_stats + ) + + state_event_map = await self.get_events(event_ids, get_prev_content=False) + + room_state = { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + "is_federatable": True, + } + + for event in state_event_map.values(): + if event.type == EventTypes.JoinRules: + room_state["join_rules"] = event.content.get("join_rule") + elif event.type == EventTypes.RoomHistoryVisibility: + room_state["history_visibility"] = event.content.get( + "history_visibility" + ) + elif event.type == EventTypes.RoomEncryption: + room_state["encryption"] = event.content.get("algorithm") + elif event.type == EventTypes.Name: + room_state["name"] = event.content.get("name") + elif event.type == EventTypes.Topic: + room_state["topic"] = event.content.get("topic") + elif event.type == EventTypes.RoomAvatar: + room_state["avatar"] = event.content.get("url") + elif event.type == EventTypes.CanonicalAlias: + room_state["canonical_alias"] = event.content.get("alias") + elif event.type == EventTypes.Create: + room_state["is_federatable"] = ( + event.content.get("m.federate", True) is True + ) + + await self.update_room_state(room_id, room_state) + + local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] + + await self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="room", + stats_id=room_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={ + "current_state_events": current_state_events_count, + "joined_members": membership_counts.get(Membership.JOIN, 0), + "invited_members": membership_counts.get(Membership.INVITE, 0), + "left_members": membership_counts.get(Membership.LEAVE, 0), + "banned_members": membership_counts.get(Membership.BAN, 0), + "local_users_in_room": len(local_users_in_room), + }, + ) + + async def _calculate_and_set_initial_state_for_user(self, user_id): + def _calculate_and_set_initial_state_for_user_txn(txn): + pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) + + txn.execute( + """ + SELECT COUNT(distinct room_id) FROM current_state_events + WHERE type = 'm.room.member' AND state_key = ? + AND membership = 'join' + """, + (user_id,), + ) + (count,) = txn.fetchone() + return count, pos + + joined_rooms, pos = await self.db_pool.runInteraction( + "calculate_and_set_initial_state_for_user", + _calculate_and_set_initial_state_for_user_txn, + ) + + await self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="user", + stats_id=user_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={"joined_rooms": joined_rooms}, + ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py new file mode 100644 index 0000000000..aaf225894e --- /dev/null +++ b/synapse/storage/databases/main/stream.py @@ -0,0 +1,1064 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" This module is responsible for getting events from the DB for pagination
+and event streaming.
+
+The order it returns events in depends on whether we are streaming forwards or
+are paginating backwards. We do this because we want to handle out of order
+messages nicely, while still returning them in the correct order when we
+paginate backwards.
+
+This is implemented by keeping two ordering columns: stream_ordering and
+topological_ordering. Stream ordering is basically insertion/received order
+(except for events from backfill requests). The topological_ordering is a
+weak ordering of events based on the pdu graph.
+
+This means that we have to have two different types of tokens, depending on
+what sort order was used:
+    - stream tokens are of the form: "s%d", which maps directly to the column
+    - topological tokens: "t%d-%d", where the integers map to the topological
+      and stream ordering columns respectively.
+"""
+
+import abc
+import logging
+from collections import namedtuple
+from typing import Optional
+
+from twisted.internet import defer
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.types import RoomStreamToken
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STREAM_SIZE = 1000
+
+
+_STREAM_TOKEN = "stream"
+_TOPOLOGICAL_TOKEN = "topological"
+
+
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple(
+    "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
+)
+
+
+def generate_pagination_where_clause(
+    direction, column_names, from_token, to_token, engine
+):
+    """Creates an SQL expression to bound the columns by the pagination
+    tokens.
+
+    For example creates an SQL expression like:
+
+        (6, 7) >= (topological_ordering, stream_ordering)
+        AND (5, 3) < (topological_ordering, stream_ordering)
+
+    would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
+
+    Note that tokens are considered to be after the row they are in, e.g. if
+    a row A has a token T, then we consider A to be before T. This convention
+    is important when figuring out inequalities for the generated SQL, and
+    produces the following result:
+        - If paginating forwards then we exclude any rows matching the from
+          token, but include those that match the to token.
+        - If paginating backwards then we include any rows matching the from
+          token, but exclude those that match the to token.
+
+    Args:
+        direction (str): Whether we're paginating backwards ("b") or
+            forwards ("f").
+        column_names (tuple[str, str]): The column names to bound. Must *not*
+            be user defined as these get inserted directly into the SQL
+            statement without escapes.
+ from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, + ) + ) + + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, + ) + ) + + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert bound in (">", "<", ">=", "<=") + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y ? AND stream_ordering <= ?" + " ORDER BY stream_ordering %s LIMIT ?" + ) % (order,) + txn.execute(sql, (room_id, from_id, to_id, limit)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + return rows + + rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(ret, rows, topo_order=from_id is None) + + if order.lower() == "desc": + ret.reverse() + + if rows: + key = "s%d" % min(r.stream_ordering for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = from_key + + return ret, key + + @defer.inlineCallbacks + def get_membership_changes_for_user(self, user_id, from_key, to_key): + from_id = RoomStreamToken.parse_stream_token(from_key).stream + to_id = RoomStreamToken.parse_stream_token(to_key).stream + + if from_key == to_key: + return [] + + if from_id: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_id) + ) + if not has_changed: + return [] + + def f(txn): + sql = ( + "SELECT m.event_id, stream_ordering FROM events AS e," + " room_memberships AS m" + " WHERE e.event_id = m.event_id" + " AND m.user_id = ?" + " AND e.stream_ordering > ? AND e.stream_ordering <= ?" 
+ " ORDER BY e.stream_ordering ASC" + ) + txn.execute(sql, (user_id, from_id, to_id)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + + return rows + + rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(ret, rows, topo_order=False) + + return ret + + @defer.inlineCallbacks + def get_recent_events_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. + + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. + + Returns: + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + events and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + + rows, token = yield self.get_recent_event_ids_for_room( + room_id, limit, end_token + ) + + events = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + return (events, token) + + @defer.inlineCallbacks + def get_recent_event_ids_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. + + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + _EventDictReturn and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + # Allow a zero limit here, and no-op. + if limit == 0: + return [], end_token + + end_token = RoomStreamToken.parse(end_token) + + rows, token = yield self.db_pool.runInteraction( + "get_recent_event_ids_for_room", + self._paginate_room_events_txn, + room_id, + from_token=end_token, + limit=limit, + ) + + # We want to return the results in ascending order. + rows.reverse() + + return rows, token + + def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or before a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering <= ?" + " AND NOT outlier" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering)) + return txn.fetchone() + + return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f) + + async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: + """Returns the current token for rooms stream. + + By default, it returns the current global stream token. Specifying a + `room_id` causes it to return the current room specific topological + token. + """ + token = self.get_room_max_stream_ordering() + if room_id is None: + return "s%d" % (token,) + else: + topo = await self.db_pool.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn, room_id + ) + return "t%d-%d" % (topo, token) + + def get_stream_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "s%d" stream token. 
+ """ + return self.db_pool.simple_select_one_onecol( + table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" + ).addCallback(lambda row: "s%d" % (row,)) + + def get_topological_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "t%d-%d" topological token. + """ + return self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "topological_ordering"), + desc="get_topological_token_for_event", + ).addCallback( + lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + ) + + def get_max_topological_token(self, room_id, stream_key): + """Get the max topological token in a room before the given stream + ordering. + + Args: + room_id (str) + stream_key (int) + + Returns: + Deferred[int] + """ + sql = ( + "SELECT coalesce(max(topological_ordering), 0) FROM events" + " WHERE room_id = ? AND stream_ordering < ?" + ) + return self.db_pool.execute( + "get_max_topological_token", None, sql, room_id, stream_key + ).addCallback(lambda r: r[0][0] if r else 0) + + def _get_max_topological_txn(self, txn, room_id): + txn.execute( + "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", + (room_id,), + ) + + rows = txn.fetchall() + return rows[0][0] if rows else 0 + + @staticmethod + def _set_before_and_after(events, rows, topo_order=True): + """Inserts ordering information to events' internal metadata from + the DB rows. + + Args: + events (list[FrozenEvent]) + rows (list[_EventDictReturn]) + topo_order (bool): Whether the events were ordered topologically + or by stream ordering. If true then all rows should have a non + null topological_ordering. + """ + for event, row in zip(events, rows): + stream = row.stream_ordering + if topo_order and row.topological_ordering: + topo = row.topological_ordering + else: + topo = None + internal = event.internal_metadata + internal.before = str(RoomStreamToken(topo, stream - 1)) + internal.after = str(RoomStreamToken(topo, stream)) + internal.order = (int(topo) if topo else 0, int(stream)) + + @defer.inlineCallbacks + def get_events_around( + self, room_id, event_id, before_limit, after_limit, event_filter=None + ): + """Retrieve events and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + event_filter (Filter|None) + + Returns: + dict + """ + + results = yield self.db_pool.runInteraction( + "get_events_around", + self._get_events_around_txn, + room_id, + event_id, + before_limit, + after_limit, + event_filter, + ) + + events_before = yield self.get_events_as_list( + list(results["before"]["event_ids"]), get_prev_content=True + ) + + events_after = yield self.get_events_as_list( + list(results["after"]["event_ids"]), get_prev_content=True + ) + + return { + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + } + + def _get_events_around_txn( + self, txn, room_id, event_id, before_limit, after_limit, event_filter + ): + """Retrieves event_ids and pagination tokens around a given event in a + room. 
+ + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + event_filter (Filter|None) + + Returns: + dict + """ + + results = self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ) + + # Paginating backwards includes the event at the token, but paginating + # forward doesn't. + before_token = RoomStreamToken( + results["topological_ordering"] - 1, results["stream_ordering"] + ) + + after_token = RoomStreamToken( + results["topological_ordering"], results["stream_ordering"] + ) + + rows, start_token = self._paginate_room_events_txn( + txn, + room_id, + before_token, + direction="b", + limit=before_limit, + event_filter=event_filter, + ) + events_before = [r.event_id for r in rows] + + rows, end_token = self._paginate_room_events_txn( + txn, + room_id, + after_token, + direction="f", + limit=after_limit, + event_filter=event_filter, + ) + events_after = [r.event_id for r in rows] + + return { + "before": {"event_ids": events_before, "token": start_token}, + "after": {"event_ids": events_after, "token": end_token}, + } + + @defer.inlineCallbacks + def get_all_new_events_stream(self, from_id, current_id, limit): + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ + + def get_all_new_events_stream_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" 
+ ) + + txn.execute(sql, (from_id, current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.db_pool.runInteraction( + "get_all_new_events_stream", get_all_new_events_stream_txn + ) + + events = yield self.get_events_as_list(event_ids) + + return upper_bound, events + + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db_pool.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db_pool.simple_select_one_onecol( + table="federation_stream_position", + retcol="stream_id", + keyvalues={"type": typ, "instance_name": self._instance_name}, + desc="get_federation_out_pos", + ) + + async def update_federation_out_pos(self, typ, stream_id): + if self._need_to_reset_federation_stream_positions: + await self.db_pool.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db_pool.simple_update_one( + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + updatevalues={"stream_id": stream_id}, + desc="update_federation_out_pos", + ) + + def _reset_federation_positions_txn(self, txn): + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. 
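The return contract of get_all_new_events_stream above (hand back either the stream ordering of the last returned row, when the limit was hit, or current_id, when the range is exhausted) can be checked with a tiny pure-Python mirror. The helper name below is invented for illustration and is not part of the patch:

def _next_from_id(rows, current_id, limit):
    # rows: (stream_ordering, event_id) tuples, at most `limit` of them.
    if len(rows) == limit:
        return rows[-1][0]
    return current_id

assert _next_from_id([(3, "$a"), (4, "$b")], current_id=10, limit=2) == 4
assert _next_from_id([(3, "$a")], current_id=10, limit=2) == 10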
+ configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db_pool.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = dict(txn) # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db_pool.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + + def has_room_changed_since(self, room_id, stream_id): + return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + def _paginate_room_events_txn( + self, + txn, + room_id, + from_token, + to_token=None, + direction="b", + limit=-1, + event_filter=None, + ): + """Returns list of events before or after a given token. + + Args: + txn + room_id (str) + from_token (RoomStreamToken): The token used to stream from + to_token (RoomStreamToken|None): A token which if given limits the + results to only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns the results + as a list of _EventDictReturn and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between + `from_token` and `to_token`), or `limit` is zero. + """ + + assert int(limit) >= 0 + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + args = [False, room_id] + if direction == "b": + order = "DESC" + else: + order = "ASC" + + bounds = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token, + to_token=to_token, + engine=self.database_engine, + ) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + args.append(int(limit)) + + select_keywords = "SELECT" + join_clause = "" + if event_filter and event_filter.labels: + # If we're not filtering on a label, then joining on event_labels will + # return as many row for a single event as the number of labels it has. To + # avoid this, only join if we're filtering on at least one label. 
+ join_clause = """ + LEFT JOIN event_labels + USING (event_id, room_id, topological_ordering) + """ + if len(event_filter.labels) > 1: + # Using DISTINCT in this SELECT query is quite expensive, because it + # requires the engine to sort on the entire (not limited) result set, + # i.e. the entire events table. We only need to use it when we're + # filtering on more than two labels, because that's the only scenario + # in which we can possibly to get multiple times the same event ID in + # the results. + select_keywords += "DISTINCT" + + sql = """ + %(select_keywords)s event_id, topological_ordering, stream_ordering + FROM events + %(join_clause)s + WHERE outlier = ? AND room_id = ? AND %(bounds)s + ORDER BY topological_ordering %(order)s, + stream_ordering %(order)s LIMIT ? + """ % { + "select_keywords": select_keywords, + "join_clause": join_clause, + "bounds": bounds, + "order": order, + } + + txn.execute(sql, args) + + rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + + if rows: + topo = rows[-1].topological_ordering + toke = rows[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = RoomStreamToken(topo, toke) + else: + # TODO (erikj): We should work out what to do here instead. + next_token = to_token if to_token else from_token + + return rows, str(next_token) + + @defer.inlineCallbacks + def paginate_room_events( + self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None + ): + """Returns list of events before or after a given token. + + Args: + room_id (str) + from_key (str): The token used to stream from + to_key (str|None): A token which if given limits the results to + only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + tuple[list[FrozenEvent], str]: Returns the results as a list of + events and a token that points to the end of the result set. If no + events are returned then the end of the stream has been reached + (i.e. there are no events between `from_key` and `to_key`). 
+ """ + + from_key = RoomStreamToken.parse(from_key) + if to_key: + to_key = RoomStreamToken.parse(to_key) + + rows, token = yield self.db_pool.runInteraction( + "paginate_room_events", + self._paginate_room_events_txn, + room_id, + from_key, + to_key, + direction, + limit, + event_filter, + ) + + events = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + return (events, token) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py new file mode 100644 index 0000000000..eedd2d96c3 --- /dev/null +++ b/synapse/storage/databases/main/tags.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage._base import db_to_json +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + + +class TagsWorkerStore(AccountDataWorkerStore): + @cached() + def get_tags_for_user(self, user_id): + """Get all the tags for a user. + + + Args: + user_id(str): The user to get the tags for. + Returns: + A deferred dict mapping from room_id strings to dicts mapping from + tag strings to tag content. + """ + + deferred = self.db_pool.simple_select_list( + "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + ) + + @deferred.addCallback + def tags_by_room(rows): + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = db_to_json(row["content"]) + return tags_by_room + + return deferred + + async def get_all_updated_tags( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for tags replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. 
+ + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = await self.db_pool.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn: + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, (user_id, room_id, tag_json))) + + return results + + batch_size = 50 + results = [] + for i in range(0, len(tag_ids), batch_size): + tags = await self.db_pool.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i : i + batch_size], + ) + results.extend(tags) + + limited = False + upto_token = current_id + if len(results) >= limit: + upto_token = results[-1][0] + limited = True + + return results, upto_token, limited + + @defer.inlineCallbacks + def get_updated_tags(self, user_id, stream_id): + """Get all the tags for the rooms where the tags have changed since the + given version + + Args: + user_id(str): The user to get the tags for. + stream_id(int): The earliest update to get for the user. + Returns: + A deferred dict mapping from room_id strings to lists of tag + strings for all the rooms that changed since the stream_id token. + """ + + def get_updated_tags_txn(txn): + sql = ( + "SELECT room_id from room_tags_revisions" + " WHERE user_id = ? AND stream_id > ?" + ) + txn.execute(sql, (user_id, stream_id)) + room_ids = [row[0] for row in txn] + return room_ids + + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) + if not changed: + return {} + + room_ids = yield self.db_pool.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) + + results = {} + if room_ids: + tags_by_room = yield self.get_tags_for_user(user_id) + for room_id in room_ids: + results[room_id] = tags_by_room.get(room_id, {}) + + return results + + def get_tags_for_room(self, user_id, room_id): + """Get all the tags for the given room + Args: + user_id(str): The user to get tags for + room_id(str): The room to get tags for + Returns: + A deferred list of string tags. + """ + return self.db_pool.simple_select_list( + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=("tag", "content"), + desc="get_tags_for_room", + ).addCallback( + lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} + ) + + +class TagsStore(TagsWorkerStore): + @defer.inlineCallbacks + def add_tag_to_room(self, user_id, room_id, tag, content): + """Add a tag to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + tag(str): The tag name to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the tag has been added. 
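The manual JSON assembly in get_tag_content above works because `content` is already JSON text in the database, so concatenating json.dumps(tag) + ":" + content inside braces yields a parseable object. A quick standalone check of that assumption (not part of the patch):

import json

tags = {"m.favourite": {"order": 0.5}, "u.work": {}}
# Mirror of the concatenation in get_tag_content: each tag name is dumped to
# JSON and glued to its already-serialised content.
parts = [json.dumps(tag) + ":" + json.dumps(content) for tag, content in tags.items()]
tag_json = "{" + ",".join(parts) + "}"
assert json.loads(tag_json) == tags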
+ """ + content_json = json.dumps(content) + + def add_tag_txn(txn, next_id): + self.db_pool.simple_upsert_txn( + txn, + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, + values={"content": content_json}, + ) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with self._account_data_id_gen.get_next() as next_id: + yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + result = self._account_data_id_gen.get_current_token() + return result + + @defer.inlineCallbacks + def remove_tag_from_room(self, user_id, room_id, tag): + """Remove a tag from a room for a user. + Returns: + A deferred that completes once the tag has been removed + """ + + def remove_tag_txn(txn, next_id): + sql = ( + "DELETE FROM room_tags " + " WHERE user_id = ? AND room_id = ? AND tag = ?" + ) + txn.execute(sql, (user_id, room_id, tag)) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with self._account_data_id_gen.get_next() as next_id: + yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + result = self._account_data_id_gen.get_current_token() + return result + + def _update_revision_txn(self, txn, user_id, room_id, next_id): + """Update the latest revision of the tags for the given user and room. + + Args: + txn: The database cursor + user_id(str): The ID of the user. + room_id(str): The ID of the room. + next_id(int): The the revision to advance to. + """ + + txn.call_after( + self._account_data_stream_cache.entity_has_changed, user_id, next_id + ) + + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + + update_sql = ( + "UPDATE room_tags_revisions" + " SET stream_id = ?" + " WHERE user_id = ?" + " AND room_id = ?" + ) + txn.execute(update_sql, (next_id, user_id, room_id)) + + if txn.rowcount == 0: + insert_sql = ( + "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)" + " VALUES (?, ?, ?)" + ) + try: + txn.execute(insert_sql, (user_id, room_id, next_id)) + except self.database_engine.module.IntegrityError: + # Ignore insertion errors. It doesn't matter if the row wasn't + # inserted because if two updates happend concurrently the one + # with the higher stream_id will not be reported to a client + # unless the previous update has completed. It doesn't matter + # which stream_id ends up in the table, as long as it is higher + # than the id that the client has. + pass diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py new file mode 100644 index 0000000000..8804c0e4ac --- /dev/null +++ b/synapse/storage/databases/main/transactions.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from canonicaljson import encode_canonical_json + +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.util.caches.expiringcache import ExpiringCache + +db_binary_type = memoryview + +logger = logging.getLogger(__name__) + + +_TransactionRow = namedtuple( + "_TransactionRow", + ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), +) + +_UpdateTransactionRow = namedtuple( + "_TransactionRow", ("response_code", "response_json") +) + +SENTINEL = object() + + +class TransactionStore(SQLBaseStore): + """A collection of queries for handling PDUs. + """ + + def __init__(self, database: DatabasePool, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) + + self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + + def get_received_txn_response(self, transaction_id, origin): + """For an incoming transaction from a given origin, check if we have + already responded to it. If so, return the response code and response + body (as a dict). + + Args: + transaction_id (str) + origin(str) + + Returns: + tuple: None if we have not previously responded to + this transaction or a 2-tuple of (int, dict) + """ + + return self.db_pool.runInteraction( + "get_received_txn_response", + self._get_received_txn_response, + transaction_id, + origin, + ) + + def _get_received_txn_response(self, txn, transaction_id, origin): + result = self.db_pool.simple_select_one_txn( + txn, + table="received_transactions", + keyvalues={"transaction_id": transaction_id, "origin": origin}, + retcols=( + "transaction_id", + "origin", + "ts", + "response_code", + "response_json", + "has_been_referenced", + ), + allow_none=True, + ) + + if result and result["response_code"]: + return result["response_code"], db_to_json(result["response_json"]) + + else: + return None + + def set_received_txn_response(self, transaction_id, origin, code, response_dict): + """Persist the response we returened for an incoming transaction, and + should return for subsequent transactions with the same transaction_id + and origin. + + Args: + txn + transaction_id (str) + origin (str) + code (int) + response_json (str) + """ + + return self.db_pool.simple_insert( + table="received_transactions", + values={ + "transaction_id": transaction_id, + "origin": origin, + "response_code": code, + "response_json": db_binary_type(encode_canonical_json(response_dict)), + "ts": self._clock.time_msec(), + }, + or_ignore=True, + desc="set_received_txn_response", + ) + + @defer.inlineCallbacks + def get_destination_retry_timings(self, destination): + """Gets the current retry timings (if any) for a given destination. 
+
+        Args:
+            destination (str)
+
+        Returns:
+            None if not retrying
+            Otherwise a dict for the retry scheme
+        """
+
+        result = self._destination_retry_cache.get(destination, SENTINEL)
+        if result is not SENTINEL:
+            return result
+
+        result = yield self.db_pool.runInteraction(
+            "get_destination_retry_timings",
+            self._get_destination_retry_timings,
+            destination,
+        )
+
+        # We don't hugely care about race conditions between getting and
+        # invalidating the cache, since we time out fairly quickly anyway.
+        self._destination_retry_cache[destination] = result
+        return result
+
+    def _get_destination_retry_timings(self, txn, destination):
+        result = self.db_pool.simple_select_one_txn(
+            txn,
+            table="destinations",
+            keyvalues={"destination": destination},
+            retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
+            allow_none=True,
+        )
+
+        if result and result["retry_last_ts"] > 0:
+            return result
+        else:
+            return None
+
+    def set_destination_retry_timings(
+        self, destination, failure_ts, retry_last_ts, retry_interval
+    ):
+        """Sets the current retry timings for a given destination.
+        Both timings should be zero if retrying is no longer occurring.
+
+        Args:
+            destination (str)
+            failure_ts (int|None) - when the server started failing (ms since epoch)
+            retry_last_ts (int) - time of last retry attempt in unix epoch ms
+            retry_interval (int) - how long until next retry in ms
+        """
+
+        self._destination_retry_cache.pop(destination, None)
+        return self.db_pool.runInteraction(
+            "set_destination_retry_timings",
+            self._set_destination_retry_timings,
+            destination,
+            failure_ts,
+            retry_last_ts,
+            retry_interval,
+        )
+
+    def _set_destination_retry_timings(
+        self, txn, destination, failure_ts, retry_last_ts, retry_interval
+    ):
+
+        if self.database_engine.can_native_upsert:
+            # Upsert retry time interval if retry_interval is zero (i.e. we're
+            # resetting it) or greater than the existing retry interval.
+
+            sql = """
+                INSERT INTO destinations (
+                    destination, failure_ts, retry_last_ts, retry_interval
+                )
+                VALUES (?, ?, ?, ?)
+                ON CONFLICT (destination) DO UPDATE SET
+                    failure_ts = EXCLUDED.failure_ts,
+                    retry_last_ts = EXCLUDED.retry_last_ts,
+                    retry_interval = EXCLUDED.retry_interval
+                WHERE
+                    EXCLUDED.retry_interval = 0
+                    OR destinations.retry_interval < EXCLUDED.retry_interval
+            """
+
+            txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+
+            return
+
+        self.database_engine.lock_table(txn, "destinations")
+
+        # We need to be careful here as the data may have changed from under us
+        # due to a worker setting the timings.
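Whether the native upsert above (and the fallback path that follows) overwrites an existing row comes down to a single predicate. A toy mirror of that guard, with a hypothetical helper name, purely for illustration:

def _should_overwrite(stored_retry_interval, new_retry_interval):
    # Overwrite when retries are being reset (interval 0) or when the new
    # interval backs off further than what is already stored.
    return new_retry_interval == 0 or stored_retry_interval < new_retry_interval

assert _should_overwrite(stored_retry_interval=5000, new_retry_interval=0)
assert _should_overwrite(stored_retry_interval=5000, new_retry_interval=10000)
assert not _should_overwrite(stored_retry_interval=5000, new_retry_interval=2000)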
+ + prev_row = self.db_pool.simple_select_one_txn( + txn, + table="destinations", + keyvalues={"destination": destination}, + retcols=("failure_ts", "retry_last_ts", "retry_interval"), + allow_none=True, + ) + + if not prev_row: + self.db_pool.simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "failure_ts": failure_ts, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + self.db_pool.simple_update_one_txn( + txn, + "destinations", + keyvalues={"destination": destination}, + updatevalues={ + "failure_ts": failure_ts, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + + def _start_cleanup_transactions(self): + return run_as_background_process( + "cleanup_transactions", self._cleanup_transactions + ) + + def _cleanup_transactions(self): + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + return self.db_pool.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py new file mode 100644 index 0000000000..37276f73f8 --- /dev/null +++ b/synapse/storage/databases/main/ui_auth.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Union + +import attr +from canonicaljson import json + +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict +from synapse.util import stringutils as stringutils + + +@attr.s +class UIAuthSessionData: + session_id = attr.ib(type=str) + # The dictionary from the client root level, not the 'auth' key. + clientdict = attr.ib(type=JsonDict) + # The URI and method the session was intiatied with. These are checked at + # each stage of the authentication to ensure that the asked for operation + # has not changed. + uri = attr.ib(type=str) + method = attr.ib(type=str) + # A string description of the operation that the current authentication is + # authorising. + description = attr.ib(type=str) + + +class UIAuthWorkerStore(SQLBaseStore): + """ + Manage user interactive authentication sessions. + """ + + async def create_ui_auth_session( + self, clientdict: JsonDict, uri: str, method: str, description: str, + ) -> UIAuthSessionData: + """ + Creates a new user interactive authentication session. + + The session can be used to track the stages necessary to authenticate a + user across multiple HTTP requests. + + Args: + clientdict: + The dictionary from the client root level, not the 'auth' key. 
+ uri: + The URI this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + method: + The method this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + Raises: + StoreError if a unique session ID cannot be generated. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db_pool.simple_insert( + table="ui_auth_sessions", + values={ + "session_id": session_id, + "clientdict": clientdict_json, + "uri": uri, + "method": method, + "description": description, + "serverdict": "{}", + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_ui_auth_session", + ) + return UIAuthSessionData( + session_id, clientdict, uri, method, description + ) + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: + """Retrieve a UI auth session. + + Args: + session_id: The ID of the session. + Returns: + A dict containing the device information. + Raises: + StoreError if the session is not found. + """ + result = await self.db_pool.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("clientdict", "uri", "method", "description"), + desc="get_ui_auth_session", + ) + + result["clientdict"] = db_to_json(result["clientdict"]) + + return UIAuthSessionData(session_id, **result) + + async def mark_ui_auth_stage_complete( + self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + ): + """ + Mark a session stage as completed. + + Args: + session_id: The ID of the corresponding session. + stage_type: The completed stage type. + result: The result of the stage verification. + Raises: + StoreError if the session cannot be found. + """ + # Add (or update) the results of the current stage to the database. + # + # Note that we need to allow for the same stage to complete multiple + # times here so that registration is idempotent. + try: + await self.db_pool.simple_upsert( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id, "stage_type": stage_type}, + values={"result": json.dumps(result)}, + desc="mark_ui_auth_stage_complete", + ) + except self.db_pool.engine.module.IntegrityError: + raise StoreError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_completed_ui_auth_stages( + self, session_id: str + ) -> Dict[str, Union[str, bool, JsonDict]]: + """ + Retrieve the completed stages of a UI authentication session. + + Args: + session_id: The ID of the session. + Returns: + The completed stages mapped to the result of the verification of + that auth-type. 
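Taken together, the session methods are meant to be driven roughly as in the hedged sketch below. The handler function, URI, and stage names are assumptions for illustration only, not code from this patch:

async def _example_two_stage_flow(store, clientdict):
    # Create the session once, then record each completed stage against it.
    session = await store.create_ui_auth_session(
        clientdict, uri="/register", method="POST", description="register a user"
    )
    await store.mark_ui_auth_stage_complete(
        session.session_id, "m.login.password", True
    )
    await store.mark_ui_auth_stage_complete(
        session.session_id, "m.login.terms", True
    )
    completed = await store.get_completed_ui_auth_stages(session.session_id)
    return {"m.login.password", "m.login.terms"} <= set(completed)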
+ """ + results = {} + for row in await self.db_pool.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ): + results[row["stage_type"]] = db_to_json(row["result"]) + + return results + + async def set_ui_auth_clientdict( + self, session_id: str, clientdict: JsonDict + ) -> None: + """ + Store an updated clientdict for a given session ID. + + Args: + session_id: The ID of this session as returned from check_auth + clientdict: + The dictionary from the client root level, not the 'auth' key. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + await self.db_pool.simple_update_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"clientdict": clientdict_json}, + desc="set_ui_auth_client_dict", + ) + + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store + Raises: + StoreError if the session cannot be found. + """ + await self.db_pool.runInteraction( + "set_ui_auth_session_data", + self._set_ui_auth_session_data_txn, + session_id, + key, + value, + ) + + def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + # Get the current value. + result = self.db_pool.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ) + + # Update it and add it back to the database. + serverdict = db_to_json(result["serverdict"]) + serverdict[key] = value + + self.db_pool.simple_update_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"serverdict": json.dumps(serverdict)}, + ) + + async def get_ui_auth_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: + """ + Retrieve data stored with set_session_data + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set + Raises: + StoreError if the session cannot be found. + """ + result = await self.db_pool.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + desc="get_ui_auth_session_data", + ) + + serverdict = db_to_json(result["serverdict"]) + + return serverdict.get(key, default) + + +class UIAuthStore(UIAuthWorkerStore): + def delete_old_ui_auth_sessions(self, expiration_time: int): + """ + Remove sessions which were last used earlier than the expiration time. + + Args: + expiration_time: The latest time that is still considered valid. + This is an epoch time in milliseconds. + + """ + return self.db_pool.runInteraction( + "delete_old_ui_auth_sessions", + self._delete_old_ui_auth_sessions_txn, + expiration_time, + ) + + def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + # Get the expired sessions. + sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" + txn.execute(sql, [expiration_time]) + session_ids = [r[0] for r in txn.fetchall()] + + # Delete the corresponding completed credentials. 
+ self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + + # Finally, delete the sessions. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py new file mode 100644 index 0000000000..d73a8e8ab9 --- /dev/null +++ b/synapse/storage/databases/main/user_directory.py @@ -0,0 +1,847 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.state import StateFilter +from synapse.storage.databases.main.state_deltas import StateDeltasStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + + +TEMP_TABLE = "_temp_populate_user_directory" + + +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): + + # How many records do we calculate before sending it to + # add_users_who_share_private_rooms? + SHARE_PRIVATE_WORKING_SET = 500 + + def __init__(self, database: DatabasePool, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.server_name = hs.hostname + + self.db_pool.updates.register_background_update_handler( + "populate_user_directory_createtables", + self._populate_user_directory_createtables, + ) + self.db_pool.updates.register_background_update_handler( + "populate_user_directory_process_rooms", + self._populate_user_directory_process_rooms, + ) + self.db_pool.updates.register_background_update_handler( + "populate_user_directory_process_users", + self._populate_user_directory_process_users, + ) + self.db_pool.updates.register_background_update_handler( + "populate_user_directory_cleanup", self._populate_user_directory_cleanup + ) + + @defer.inlineCallbacks + def _populate_user_directory_createtables(self, progress, batch_size): + + # Get all the rooms that we want to process. 
+ def _make_staging_area(txn): + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" + ) + txn.execute(sql) + + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_position(position TEXT NOT NULL)" + ) + txn.execute(sql) + + # Get rooms we want to process from the database + sql = """ + SELECT room_id, count(*) FROM current_state_events + GROUP BY room_id + """ + txn.execute(sql) + rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + del rooms + + # If search all users is on, get all the users we want to add. + if self.hs.config.user_directory_search_all_users: + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_users(user_id TEXT NOT NULL)" + ) + txn.execute(sql) + + txn.execute("SELECT name FROM users") + users = [{"user_id": x[0]} for x in txn.fetchall()] + + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + + new_pos = yield self.get_max_stream_id_in_current_state_deltas() + yield self.db_pool.runInteraction( + "populate_user_directory_temp_build", _make_staging_area + ) + yield self.db_pool.simple_insert( + TEMP_TABLE + "_position", {"position": new_pos} + ) + + yield self.db_pool.updates._end_background_update( + "populate_user_directory_createtables" + ) + return 1 + + @defer.inlineCallbacks + def _populate_user_directory_cleanup(self, progress, batch_size): + """ + Update the user directory stream position, then clean up the old tables. + """ + position = yield self.db_pool.simple_select_one_onecol( + TEMP_TABLE + "_position", None, "position" + ) + yield self.update_user_directory_stream_pos(position) + + def _delete_staging_area(txn): + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + + yield self.db_pool.runInteraction( + "populate_user_directory_cleanup", _delete_staging_area + ) + + yield self.db_pool.updates._end_background_update( + "populate_user_directory_cleanup" + ) + return 1 + + @defer.inlineCallbacks + def _populate_user_directory_process_rooms(self, progress, batch_size): + """ + Args: + progress (dict) + batch_size (int): Maximum number of state events to process + per cycle. + """ + state = self.hs.get_state_handler() + + # If we don't have progress filed, delete everything. + if not progress: + yield self.delete_all_from_user_dir() + + def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. + sql = """ + SELECT room_id, events FROM %s + ORDER BY events DESC + LIMIT 250 + """ % ( + TEMP_TABLE + "_rooms", + ) + txn.execute(sql) + rooms_to_work_on = txn.fetchall() + + if not rooms_to_work_on: + return None + + # Get how many are left to process, so we can give status on how + # far we are in processing + txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") + progress["remaining"] = txn.fetchone()[0] + + return rooms_to_work_on + + rooms_to_work_on = yield self.db_pool.runInteraction( + "populate_user_directory_temp_read", _get_next_batch + ) + + # No more rooms -- complete the transaction. 
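All four populate_user_directory_* handlers follow the same background-update contract: take the stored progress dict plus a batch size, do at most that much work, persist progress, and end the update once the staging table is drained. A stripped-down sketch of that contract (illustrative only; `end_update` stands in for the real updates API, and the real rooms handler weights its batch by state-event count rather than by item count):

async def _process_in_batches(work_queue, progress, batch_size, end_update):
    # Nothing left to do: mark the background update as finished. The real
    # handlers return 1 in this case, as above.
    if not work_queue:
        await end_update()
        return 1

    done = 0
    while work_queue and done < batch_size:
        work_queue.pop(0)  # stand-in for "process one room / user"
        done += 1

    # Progress is persisted between calls so the update can resume later.
    progress["remaining"] = len(work_queue)
    return done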
+ if not rooms_to_work_on: + yield self.db_pool.updates._end_background_update( + "populate_user_directory_process_rooms" + ) + return 1 + + logger.debug( + "Processing the next %d rooms of %d remaining" + % (len(rooms_to_work_on), progress["remaining"]) + ) + + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: + is_in_room = yield self.is_host_joined(room_id, self.server_name) + + if is_in_room: + is_public = yield self.is_room_world_readable_or_publicly_joinable( + room_id + ) + + users_with_profile = yield defer.ensureDeferred( + state.get_current_users_in_room(room_id) + ) + user_ids = set(users_with_profile) + + # Update each user in the user directory. + for user_id, profile in users_with_profile.items(): + yield self.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url + ) + + to_insert = set() + + if is_public: + for user_id in user_ids: + if self.get_if_app_services_interested_in_user(user_id): + continue + + to_insert.add(user_id) + + if to_insert: + yield self.add_users_in_public_rooms(room_id, to_insert) + to_insert.clear() + else: + for user_id in user_ids: + if not self.hs.is_mine_id(user_id): + continue + + if self.get_if_app_services_interested_in_user(user_id): + continue + + for other_user_id in user_ids: + if user_id == other_user_id: + continue + + user_set = (user_id, other_user_id) + to_insert.add(user_set) + + # If it gets too big, stop and write to the database + # to prevent storing too much in RAM. + if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: + yield self.add_users_who_share_private_room( + room_id, to_insert + ) + to_insert.clear() + + if to_insert: + yield self.add_users_who_share_private_room(room_id, to_insert) + to_insert.clear() + + # We've finished a room. Delete it from the table. + yield self.db_pool.simple_delete_one( + TEMP_TABLE + "_rooms", {"room_id": room_id} + ) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.db_pool.runInteraction( + "populate_user_directory", + self.db_pool.updates._background_update_progress_txn, + "populate_user_directory_process_rooms", + progress, + ) + + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + return processed_event_count + + return processed_event_count + + @defer.inlineCallbacks + def _populate_user_directory_process_users(self, progress, batch_size): + """ + If search_all_users is enabled, add all of the users to the user directory. + """ + if not self.hs.config.user_directory_search_all_users: + yield self.db_pool.updates._end_background_update( + "populate_user_directory_process_users" + ) + return 1 + + def _get_next_batch(txn): + sql = "SELECT user_id FROM %s LIMIT %s" % ( + TEMP_TABLE + "_users", + str(batch_size), + ) + txn.execute(sql) + users_to_work_on = txn.fetchall() + + if not users_to_work_on: + return None + + users_to_work_on = [x[0] for x in users_to_work_on] + + # Get how many are left to process, so we can give status on how + # far we are in processing + sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" + txn.execute(sql) + progress["remaining"] = txn.fetchone()[0] + + return users_to_work_on + + users_to_work_on = yield self.db_pool.runInteraction( + "populate_user_directory_temp_read", _get_next_batch + ) + + # No more users -- complete the transaction. 
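The private-room branch above pairs every local member with every other member and flushes the pairs in chunks of SHARE_PRIVATE_WORKING_SET so the in-memory set stays bounded. A small self-contained sketch of that pairing (illustrative; `is_local` and `flush` stand in for hs.is_mine_id and add_users_who_share_private_room, and the real loop also skips users that an application service is interested in):

SHARE_PRIVATE_WORKING_SET = 500

def _pair_private_room_members(user_ids, is_local, flush):
    to_insert = set()
    for user_id in user_ids:
        if not is_local(user_id):
            continue
        for other_user_id in user_ids:
            if user_id == other_user_id:
                continue
            to_insert.add((user_id, other_user_id))
            # Flush early so we never hold too many pairs in RAM.
            if len(to_insert) >= SHARE_PRIVATE_WORKING_SET:
                flush(to_insert)
                to_insert.clear()
    if to_insert:
        flush(to_insert)

rows = []
_pair_private_room_members(
    {"@a:local", "@b:local", "@c:remote"},
    is_local=lambda u: u.endswith(":local"),
    flush=lambda batch: rows.extend(sorted(batch)),
)
# rows == [("@a:local", "@b:local"), ("@a:local", "@c:remote"),
#          ("@b:local", "@a:local"), ("@b:local", "@c:remote")]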
+ if not users_to_work_on: + yield self.db_pool.updates._end_background_update( + "populate_user_directory_process_users" + ) + return 1 + + logger.debug( + "Processing the next %d users of %d remaining" + % (len(users_to_work_on), progress["remaining"]) + ) + + for user_id in users_to_work_on: + profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) + yield self.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url + ) + + # We've finished processing a user. Delete it from the table. + yield self.db_pool.simple_delete_one( + TEMP_TABLE + "_users", {"user_id": user_id} + ) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.db_pool.runInteraction( + "populate_user_directory", + self.db_pool.updates._background_update_progress_txn, + "populate_user_directory_process_users", + progress, + ) + + return len(users_to_work_on) + + @defer.inlineCallbacks + def is_room_world_readable_or_publicly_joinable(self, room_id): + """Check if the room is either world_readable or publically joinable + """ + + # Create a state filter that only queries join and history state event + types_to_filter = ( + (EventTypes.JoinRules, ""), + (EventTypes.RoomHistoryVisibility, ""), + ) + + current_state_ids = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types(types_to_filter) + ) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + if join_rules_id: + join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + if join_rule_ev: + if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: + return True + + hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_id: + hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + if hist_vis_ev: + if hist_vis_ev.content.get("history_visibility") == "world_readable": + return True + + return False + + def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + """ + Update or add a user's profile in the user directory. 
+ """ + + def _update_profile_in_user_dir_txn(txn): + new_entry = self.db_pool.simple_upsert_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + values={"display_name": display_name, "avatar_url": avatar_url}, + lock=False, # We're only inserter + ) + + if isinstance(self.database_engine, PostgresEngine): + # We weight the localpart most highly, then display name and finally + # server name + if self.database_engine.can_native_upsert: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector + """ + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) + else: + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + user_id, + ), + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + value = "%s %s" % (user_id, display_name) if display_name else user_id + self.db_pool.simple_upsert_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + values={"value": value}, + lock=False, # We're only inserter + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.db_pool.runInteraction( + "update_profile_in_user_dir", _update_profile_in_user_dir_txn + ) + + def add_users_who_share_private_room(self, room_id, user_id_tuples): + """Insert entries into the users_who_share_private_rooms table. The first + user should be a local user. + + Args: + room_id (str) + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + + def _add_users_who_share_room_txn(txn): + self.db_pool.simple_upsert_many_txn( + txn, + table="users_who_share_private_rooms", + key_names=["user_id", "other_user_id", "room_id"], + key_values=[ + (user_id, other_user_id, room_id) + for user_id, other_user_id in user_id_tuples + ], + value_names=(), + value_values=None, + ) + + return self.db_pool.runInteraction( + "add_users_who_share_room", _add_users_who_share_room_txn + ) + + def add_users_in_public_rooms(self, room_id, user_ids): + """Insert entries into the users_who_share_private_rooms table. The first + user should be a local user. 
+ + Args: + room_id (str) + user_ids (list[str]) + """ + + def _add_users_in_public_rooms_txn(txn): + + self.db_pool.simple_upsert_many_txn( + txn, + table="users_in_public_rooms", + key_names=["user_id", "room_id"], + key_values=[(user_id, room_id) for user_id in user_ids], + value_names=(), + value_values=None, + ) + + return self.db_pool.runInteraction( + "add_users_in_public_rooms", _add_users_in_public_rooms_txn + ) + + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_public_rooms") + txn.execute("DELETE FROM users_who_share_private_rooms") + txn.call_after(self.get_user_in_directory.invalidate_all) + + return self.db_pool.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self.db_pool.simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self.db_pool.simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + +class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): + + # How many records do we calculate before sending it to + # add_users_who_share_private_rooms? + SHARE_PRIVATE_WORKING_SET = 500 + + def __init__(self, database: DatabasePool, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self.db_pool.simple_delete_txn( + txn, table="user_directory", keyvalues={"user_id": user_id} + ) + self.db_pool.simple_delete_txn( + txn, table="user_directory_search", keyvalues={"user_id": user_id} + ) + self.db_pool.simple_delete_txn( + txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} + ) + self.db_pool.simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"other_user_id": user_id}, + ) + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.db_pool.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn + ) + + @defer.inlineCallbacks + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory because they're + in the given room_id + """ + user_ids_share_pub = yield self.db_pool.simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share_priv = yield self.db_pool.simple_select_onecol( + table="users_who_share_private_rooms", + keyvalues={"room_id": room_id}, + retcol="other_user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_share_pub) + user_ids.update(user_ids_share_priv) + + return user_ids + + def remove_user_who_share_room(self, user_id, room_id): + """ + Deletes entries in the users_who_share_*_rooms table. The first + user should be a local user. 
+ + Args: + user_id (str) + room_id (str) + """ + + def _remove_user_who_share_room_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id, "room_id": room_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"other_user_id": user_id, "room_id": room_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="users_in_public_rooms", + keyvalues={"user_id": user_id, "room_id": room_id}, + ) + + return self.db_pool.runInteraction( + "remove_user_who_share_room", _remove_user_who_share_room_txn + ) + + @defer.inlineCallbacks + def get_user_dir_rooms_user_is_in(self, user_id): + """ + Returns the rooms that a user is in. + + Args: + user_id(str): Must be a local user + + Returns: + list: user_id + """ + rows = yield self.db_pool.simple_select_onecol( + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_is_in", + ) + + pub_rows = yield self.db_pool.simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_is_in", + ) + + users = set(pub_rows) + users.update(rows) + return list(users) + + @defer.inlineCallbacks + def get_rooms_in_common_for_users(self, user_id, other_user_id): + """Given two user_ids find out the list of rooms they share. + """ + sql = """ + SELECT room_id FROM ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE type = 'm.room.member' + AND m.membership = 'join' + AND state_key = ? + ) AS f1 INNER JOIN ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE type = 'm.room.member' + AND m.membership = 'join' + AND state_key = ? + ) f2 USING (room_id) + """ + + rows = yield self.db_pool.execute( + "get_rooms_in_common_for_users", None, sql, user_id, other_user_id + ) + + return [room_id for room_id, in rows] + + def get_user_directory_stream_pos(self): + return self.db_pool.simple_select_one_onecol( + table="user_directory_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_user_directory_stream_pos", + ) + + @defer.inlineCallbacks + def search_user_dir(self, user_id, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": , # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": , + "display_name": , + "avatar_url": + } + ] + } + """ + + if self.hs.config.user_directory_search_all_users: + join_args = (user_id,) + where_clause = "user_id != ?" + else: + join_args = (user_id,) + where_clause = """ + ( + EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id) + OR EXISTS ( + SELECT 1 FROM users_who_share_private_rooms + WHERE user_id = ? AND other_user_id = t.user_id + ) + ) + """ + + if isinstance(self.database_engine, PostgresEngine): + full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + + # We order by rank and then if they have profile info + # The ranking algorithm is hand tweaked for "best" results. Broadly + # the idea is we give a higher weight to exact matches. 
+ # The array of numbers are the weights for the various part of the + # search: (domain, _, display name, localpart) + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search as t + INNER JOIN user_directory AS d USING (user_id) + WHERE + %s + AND vector @@ to_tsquery('english', ?) + ORDER BY + (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) + * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) + * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) + * ( + 3 * ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + + ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + ) + DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % ( + where_clause, + ) + args = join_args + (full_query, exact_query, prefix_query, limit + 1) + elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_sqlite(search_term) + + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search as t + INNER JOIN user_directory AS d USING (user_id) + WHERE + %s + AND value MATCH ? + ORDER BY + rank(matchinfo(user_directory_search)) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % ( + where_clause, + ) + args = join_args + (search_query, limit + 1) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + results = yield self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + ) + + limited = len(results) > limit + + return {"limited": limited, "results": results} + + +def _parse_query_sqlite(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + return " & ".join("(%s* OR %s)" % (result, result) for result in results) + + +def _parse_query_postgres(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + both = " & ".join("(%s:* | %s)" % (result, result) for result in results) + exact = " & ".join("%s" % (result,) for result in results) + prefix = " & ".join("%s:*" % (result,) for result in results) + + return both, exact, prefix diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py new file mode 100644 index 0000000000..ab6cb2c1f6 --- /dev/null +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
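For reference, the two query parsers defined at the end of user_directory.py above turn a free-text search term into engine-specific match expressions. The values below are what the functions as written produce (illustrative; the import path is the module added by this patch):

from synapse.storage.databases.main.user_directory import (
    _parse_query_postgres,
    _parse_query_sqlite,
)

print(_parse_query_sqlite("jane doe"))
# (jane* OR jane) & (doe* OR doe)

both, exact, prefix = _parse_query_postgres("jane doe")
# both   == "(jane:* | jane) & (doe:* | doe)"  -> used in the WHERE clause
# exact  == "jane & doe"                       -> first ts_rank_cd term (3x weight)
# prefix == "jane:* & doe:*"                   -> second ts_rank_cd term
# so exact matches rank above prefix-only matches in search_user_dir.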
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator + +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList + + +class UserErasureWorkerStore(SQLBaseStore): + @cached() + def is_user_erased(self, user_id): + """ + Check if the given user id has requested erasure + + Args: + user_id (str): full user id to check + + Returns: + Deferred[bool]: True if the user has requested erasure + """ + return self.db_pool.simple_select_onecol( + table="erased_users", + keyvalues={"user_id": user_id}, + retcol="1", + desc="is_user_erased", + ).addCallback(operator.truth) + + @cachedList( + cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True + ) + def are_users_erased(self, user_ids): + """ + Checks which users in a list have requested erasure + + Args: + user_ids (iterable[str]): full user id to check + + Returns: + Deferred[dict[str, bool]]: + for each user, whether the user has requested erasure. + """ + # this serves the dual purpose of (a) making sure we can do len and + # iterate it multiple times, and (b) avoiding duplicates. + user_ids = tuple(set(user_ids)) + + rows = yield self.db_pool.simple_select_many_batch( + table="erased_users", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="are_users_erased", + ) + erased_users = {row["user_id"] for row in rows} + + res = {u: u in erased_users for u in user_ids} + return res + + +class UserErasureStore(UserErasureWorkerStore): + def mark_user_erased(self, user_id: str) -> None: + """Indicate that user_id wishes their message history to be erased. + + Args: + user_id: full user_id to be erased + """ + + def f(txn): + # first check if they are already in the list + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) + if txn.fetchone(): + return + + # they are not already there: do the insert. + txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,)) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) + + return self.db_pool.runInteraction("mark_user_erased", f) + + def mark_user_not_erased(self, user_id: str) -> None: + """Indicate that user_id is no longer erased. + + Args: + user_id: full user_id to be un-erased + """ + + def f(txn): + # first check if they are already in the list + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) + if not txn.fetchone(): + return + + # They are there, delete them. + self.simple_delete_one_txn( + txn, "erased_users", keyvalues={"user_id": user_id} + ) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) + + return self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/synapse/storage/databases/state/__init__.py b/synapse/storage/databases/state/__init__.py new file mode 100644 index 0000000000..c90d022899 --- /dev/null +++ b/synapse/storage/databases/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
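The erasure helpers in user_erasure_store.py above return a per-user map. A minimal sketch of how a caller might consume it (illustrative only; `store` is assumed to mix in UserErasureWorkerStore):

async def _visible_senders(store, sender_ids):
    # are_users_erased collapses duplicates and returns an entry for every
    # requested user, e.g. {"@alice:hs": True, "@bob:hs": False}.
    erased = await store.are_users_erased(sender_ids)
    return {user_id for user_id, is_erased in erased.items() if not is_erased}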
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.databases.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py new file mode 100644 index 0000000000..1e2d584098 --- /dev/null +++ b/synapse/storage/databases/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. 
This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) %s + ORDER BY type, state_key, state_group DESC + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. 
Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db_pool.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db_pool.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db_pool.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db_pool.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. 
+ continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in curr_state.items() + if prev_state.get(key, None) != value + } + + self.db_pool.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db_pool.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db_pool.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_state.items() + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db_pool.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db_pool.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db_pool.runWithConnection(reindex_txn) + + yield self.db_pool.updates._end_background_update( + self.STATE_GROUP_INDEX_UPDATE_NAME + ) + + return 1 diff --git a/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql new file mode 100644 index 0000000000..ae09fa0065 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
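bg_updates.py above deals throughout with state-group delta chains: state_groups_state holds only the delta against prev_state_group, state_group_edges records the chain, and resolving a group walks that chain with the nearest group winning per (type, state_key), which is what the recursive CTE and the sqlite walk in _get_state_groups_from_groups_txn implement. A toy, in-memory model of those semantics (illustrative, not part of the patch):

edges = {3: 2, 2: 1}  # state_group -> prev_state_group
deltas = {
    1: {("m.room.create", ""): "$create", ("m.room.member", "@a:hs"): "$join_a"},
    2: {("m.room.member", "@b:hs"): "$join_b"},
    3: {("m.room.member", "@a:hs"): "$leave_a"},  # overrides group 1's entry
}

def resolve(group):
    state = {}
    while group is not None:
        for key, event_id in deltas[group].items():
            state.setdefault(key, event_id)  # nearer groups take precedence
        group = edges.get(group)
    return state

def count_hops(group):
    hops = 0
    while group in edges:  # mirrors the sqlite branch of _count_state_group_hops_txn
        group = edges[group]
        hops += 1
    return hops

assert resolve(3)[("m.room.member", "@a:hs")] == "$leave_a"
assert count_hops(3) == 2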
+ */ + +DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/databases/state/schema/delta/30/state_stream.sql b/synapse/storage/databases/state/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..e85699e82e --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/30/state_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* We used to create a table called current_state_resets, but this is no + * longer used and is removed in delta 54. + */ + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); diff --git a/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/databases/state/schema/delta/35/add_state_index.sql b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql new file mode 100644 index 0000000000..33980d02f0 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/databases/state/schema/delta/35/state.sql b/synapse/storage/databases/state/schema/delta/35/state.sql new file mode 100644 index 0000000000..0f1fa68a89 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/35/state.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql new file mode 100644 index 0000000000..97e5067ef4 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/databases/state/schema/delta/47/state_group_seq.py b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..9fd1ccf6f7 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py @@ -0,0 +1,34 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/databases/state/schema/full_schemas/54/full.sql b/synapse/storage/databases/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/databases/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py new file mode 100644 index 0000000000..7f104ad936 --- /dev/null +++ b/synapse/storage/databases/state/store.py @@ -0,0 +1,644 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from collections import namedtuple +from typing import Dict, Iterable, List, Set, Tuple + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.state import StateFilter +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator +from synapse.types import StateMap +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: DatabasePool, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. 
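The comment above motivates splitting the state cache into a member half and a non-member half so that lazy-loading queries stay cache friendly. A toy illustration of the resulting lookup and merge, using plain dicts in place of DictionaryCache (illustrative, not part of the patch):

non_member_cache = {
    42: {("m.room.create", ""): "$create", ("m.room.name", ""): "$name"},
}
member_cache = {
    42: {("m.room.member", "@a:hs"): "$join_a", ("m.room.member", "@b:hs"): "$join_b"},
}

def get_state(group, wanted_members):
    state = dict(non_member_cache.get(group, {}))  # all non-member state
    members = member_cache.get(group, {})
    for user_id in wanted_members:  # only the members the caller asked for
        key = ("m.room.member", user_id)
        if key in members:
            state[key] = members[key]
    return state

print(get_state(42, ["@a:hs"]))
# {('m.room.create', ''): '$create', ('m.room.name', ''): '$name',
#  ('m.room.member', '@a:hs'): '$join_a'}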
+ + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000, + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", 500000, + ) + + def get_max_state_group_txn(txn: Cursor): + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + return txn.fetchone()[0] + + self._state_group_seq_gen = build_sequence_generator( + self.database_engine, get_max_state_group_txn, "state_group_id_seq" + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db_pool.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db_pool.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + async def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ) -> Dict[int, StateMap[str]]: + """Returns the state groups for a given set of groups from the + database, filtering on types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + Returns: + Dict of state group to state map. + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = await self.db_pool.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. 
+ for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + async def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ) -> Dict[int, StateMap[str]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want + to get the state. + state_filter: The state filter used to fetch state + from the database. + Returns: + Dict of state group to state map. + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = await self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in group_to_state_dict.items(): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache( + self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter + ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups: list of state groups for which we want to get the state. + cache: the cache of group ids to state dicts which + we will pass through - either the normal state cache or the + specific members state cache. + state_filter: The state filter used to fetch state from the + database. + + Returns: + Tuple of dict of state_group_id to state map of entries in the + cache, and the state group ids either missing from the cache or + incomplete. 
+        """
+        results = {}
+        incomplete_groups = set()
+        for group in set(groups):
+            state_dict_ids, got_all = self._get_state_for_group_using_cache(
+                cache, group, state_filter
+            )
+            results[group] = state_dict_ids
+
+            if not got_all:
+                incomplete_groups.add(group)
+
+        return results, incomplete_groups
+
+    def _insert_into_cache(
+        self,
+        group_to_state_dict,
+        state_filter,
+        cache_seq_num_members,
+        cache_seq_num_non_members,
+    ):
+        """Inserts results from querying the database into the relevant cache.
+
+        Args:
+            group_to_state_dict (dict): The new entries pulled from the database.
+                Map from state group to state dict
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+            cache_seq_num_members (int): Sequence number of member cache since
+                last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of non-member cache since
+                last lookup in cache
+        """
+
+        # We need to work out which types we've fetched from the DB for the
+        # member vs non-member caches. This should be as accurate as possible,
+        # but can be an underestimate (e.g. when we have wild cards)
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+        if member_filter.is_full():
+            # We fetched all member events
+            member_types = None
+        else:
+            # `concrete_types()` will only return a subset when there are wild
+            # cards in the filter, but that's fine.
+            member_types = member_filter.concrete_types()
+
+        if non_member_filter.is_full():
+            # We fetched all non-member events
+            non_member_types = None
+        else:
+            non_member_types = non_member_filter.concrete_types()
+
+        for group, group_state_dict in group_to_state_dict.items():
+            state_dict_members = {}
+            state_dict_non_members = {}
+
+            for k, v in group_state_dict.items():
+                if k[0] == EventTypes.Member:
+                    state_dict_members[k] = v
+                else:
+                    state_dict_non_members[k] = v
+
+            self._state_group_members_cache.update(
+                cache_seq_num_members,
+                key=group,
+                value=state_dict_members,
+                fetched_keys=member_types,
+            )
+
+            self._state_group_cache.update(
+                cache_seq_num_non_members,
+                key=group,
+                value=state_dict_non_members,
+                fetched_keys=non_member_types,
+            )
+
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
+                # AFAIK, this can never happen
+                raise Exception("current_state_ids cannot be None")
+
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't too long, as otherwise read performance degrades.
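The hop limit matters because reading a state group back means walking its chain of `state_group_edges` and layering each delta on top of its ancestor, so long chains make reads slow. The simplified sketch below is an illustration only (it is not the lookup code this store uses, it assumes a sqlite-style `?` placeholder, and it ignores batching), but it shows roughly what that walk involves.

def _resolve_state_group_sketch(txn, state_group):
    # Collect the chain of groups from the requested one back to its oldest
    # ancestor, bounded by MAX_STATE_DELTA_HOPS.
    chain = []
    group = state_group
    while group is not None and len(chain) <= MAX_STATE_DELTA_HOPS:
        chain.append(group)
        txn.execute(
            "SELECT prev_state_group FROM state_group_edges WHERE state_group = ?",
            (group,),
        )
        row = txn.fetchone()
        group = row[0] if row else None

    # Apply deltas from the oldest ancestor down, so newer entries win.
    state = {}
    for group in reversed(chain):
        txn.execute(
            "SELECT type, state_key, event_id FROM state_groups_state"
            " WHERE state_group = ?",
            (group,),
        )
        for typ, state_key, event_id in txn:
            state[(typ, state_key)] = event_id
    return state

Keeping the chain short, via the check against MAX_STATE_DELTA_HOPS just below, is what bounds the cost of that loop.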
+ if prev_group: + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db_pool.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_ids.items() + ], + ) + else: + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in current_state_ids.items() + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in current_state_ids.items() + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in current_state_ids.items() + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db_pool.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db_pool.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db_pool.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = { + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + } + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. 
+ for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db_pool.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db_pool.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in curr_state.items() + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: + """Fetch the previous groups of the given state groups. + + Args: + state_groups + + Returns: + A mapping from state group to previous state group. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db_pool.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db_pool.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... 
and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 4a164834d9..f15b95e633 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,8 +29,8 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.data_stores import DataStores -from synapse.storage.data_stores.main.events import DeltaState +from synapse.storage.databases import Databases +from synapse.storage.databases.main.events import DeltaState from synapse.types import StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -179,7 +179,7 @@ class EventsPersistenceStorage(object): current state and forward extremity changes. """ - def __init__(self, hs, stores: DataStores): + def __init__(self, hs, stores: Databases): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 9cc3b51fe6..1c5f305132 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -47,8 +47,8 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): - """Prepares a database for usage. Will either create all necessary tables +def prepare_database(db_conn, database_engine, config, databases=["main", "state"]): + """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. If `config` is None then prepare_database will assert that no upgrade is @@ -60,8 +60,8 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already - data_stores (list[str]): The name of the data stores that will be used - with this database. Defaults to all data stores. + databases (list[str]): The name of the databases that will be used + with this physical database. Defaults to all databases. 
""" try: @@ -87,10 +87,10 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta upgraded, database_engine, config, - data_stores=data_stores, + databases=databases, ) else: - _setup_new_database(cur, database_engine, data_stores=data_stores) + _setup_new_database(cur, database_engine, databases=databases) # check if any of our configured dynamic modules want a database if config is not None: @@ -103,9 +103,9 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta raise -def _setup_new_database(cur, database_engine, data_stores): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas, including schemas from the given data +def _setup_new_database(cur, database_engine, databases): + """Sets up the physical database by finding a base set of "full schemas" and + then applying any necessary deltas, including schemas from the given data stores. The "full_schemas" directory has subdirectories named after versions. This @@ -138,8 +138,8 @@ def _setup_new_database(cur, database_engine, data_stores): Args: cur (Cursor): a database cursor database_engine (DatabaseEngine) - data_stores (list[str]): The names of the data stores to instantiate - on the given database. + databases (list[str]): The names of the databases to instantiate + on the given physical database. """ # We're about to set up a brand new database so we check that its @@ -176,13 +176,13 @@ def _setup_new_database(cur, database_engine, data_stores): directories.extend( os.path.join( dir_path, - "data_stores", - data_store, + "databases", + database, "schema", "full_schemas", str(max_current_ver), ) - for data_store in data_stores + for database in databases ) directory_entries = [] @@ -219,7 +219,7 @@ def _setup_new_database(cur, database_engine, data_stores): upgraded=False, database_engine=database_engine, config=None, - data_stores=data_stores, + databases=databases, is_empty=True, ) @@ -231,10 +231,10 @@ def _upgrade_existing_database( upgraded, database_engine, config, - data_stores, + databases, is_empty=False, ): - """Upgrades an existing database. + """Upgrades an existing physical database. Delta files can either be SQL stored in *.sql files, or python modules in *.py. @@ -285,8 +285,8 @@ def _upgrade_existing_database( config (synapse.config.homeserver.HomeServerConfig|None): None if we are initialising a blank database, otherwise the application config - data_stores (list[str]): The names of the data stores to instantiate - on the given database. + databases (list[str]): The names of the databases to instantiate + on the given physical database. is_empty (bool): Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ @@ -303,8 +303,8 @@ def _upgrade_existing_database( # some of the deltas assume that config.server_name is set correctly, so now # is a good time to run the sanity check. 
- if not is_empty and "main" in data_stores: - from synapse.storage.data_stores.main import check_database_before_upgrade + if not is_empty and "main" in databases: + from synapse.storage.databases.main import check_database_before_upgrade check_database_before_upgrade(cur, database_engine, config) @@ -330,11 +330,9 @@ def _upgrade_existing_database( # First we find the directories to search in delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) directories = [delta_dir] - for data_store in data_stores: + for database in databases: directories.append( - os.path.join( - dir_path, "data_stores", data_store, "schema", "delta", str(v) - ) + os.path.join(dir_path, "databases", database, "schema", "delta", str(v)) ) # Used to check if we have any duplicate file names diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 787cebfbec..e2ddd01290 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -20,7 +20,7 @@ from typing import Dict, Set, Tuple from typing_extensions import Deque -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -239,7 +239,7 @@ class MultiWriterIdGenerator: def __init__( self, db_conn, - db: Database, + db: DatabasePool, instance_name: str, table: str, instance_column: str, diff --git a/synmark/__init__.py b/synmark/__init__.py index afe4fad8cb..53698bd5ab 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None): stor = hs.get_datastore() # Run the database background updates. - if hasattr(stor.db.updates, "do_next_background_update"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + if hasattr(stor.db_pool.updates, "do_next_background_update"): + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) def cleanup(): for i in cleanup_tasks: diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 5dc3795643..0e666492f6 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -15,7 +15,7 @@ from synapse.rest import admin from synapse.rest.client.v1 import login, room -from synapse.storage.data_stores.main import stats +from synapse.storage.databases.main import stats from tests import unittest @@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. 
""" # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms_2", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store.db.simple_select_list( + return self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store.db.simple_select_one( + self.store.db_pool.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_initial_room(self): @@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r = self.get_success(self.get_all_room_state()) @@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_update_one( + self.store.db_pool.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now, before the table is actually ingested, add some more events. 
@@ -217,7 +217,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms_2", @@ -226,7 +226,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -236,12 +236,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) self.reactor.advance(86401) @@ -703,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -723,9 +723,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms_2", @@ -735,7 +735,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -745,7 +745,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -756,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r1stats_complete = self._get_current_stats("room", r1) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 23fcc372dd..31ed89a5cd 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -339,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -350,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ 
-362,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -374,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -384,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -394,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -437,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() @@ -476,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 06575ba0a6..ae60874ec3 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -65,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. 
- self.worker_hs.get_datastore().db = hs.get_datastore().db + self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler @@ -198,7 +198,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer = self.hs.get_replication_streamer() store = self.hs.get_datastore() - self.database = store.db + self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -254,7 +254,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool + store.db_pool._db_pool = self.database_pool._db_pool repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index cec1cf928f..408c568a27 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -566,7 +566,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", @@ -667,7 +667,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 5a50e4fdd4..319e2c2325 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git 
a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ef296e7dab..1b516b7976 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,11 +24,11 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TestTransactionStore, self).__init__(database, db_conn, hs) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 940b166129..2efbc97c2e 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,7 +9,9 @@ from tests import unittest class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + self.updates = ( + self.hs.get_datastore().db_pool.updates + ) # type: BackgroundUpdater # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -29,7 +31,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() self.get_success( - store.db.simple_insert( + store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) @@ -40,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield store.db.runInteraction( + yield store.db_pool.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index b589506c60..efcaeef1e7 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,7 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest @@ -57,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool self.datastore = SQLBaseStore(db, None, hs) @@ -66,7 +66,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( + yield self.datastore.db_pool.simple_insert( table="tablename", values={"columname": "Value"} ) @@ 
-78,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( + yield self.datastore.db_pool.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -93,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db.simple_select_one_onecol( + value = yield self.datastore.db_pool.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -107,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db.simple_select_one( + ret = yield self.datastore.db_pool.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -123,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db.simple_select_one( + ret = yield self.datastore.db_pool.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -138,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db.simple_select_list( + ret = yield self.datastore.db_pool.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -151,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( + yield self.datastore.db_pool.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -166,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( + yield self.datastore.db_pool.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -181,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_delete_one( + yield self.datastore.db_pool.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 43425c969a..3fab5a5248 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Make sure we don't clash with in progress updates. 
self.assertTrue( - self.store.db.updates._all_done, "Background updates are still ongoing" + self.store.db_pool.updates._all_done, "Background updates are still ongoing" ) schema_path = os.path.join( prepare_database.dir_path, - "data_stores", + "databases", "main", "schema", "delta", @@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "test_delete_forward_extremities", run_delta_file ) ) # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_soft_failed_extremities_handled_correctly(self): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 3b483bc7f0..224ea6fd79 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -86,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -117,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -204,10 +204,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -225,7 +225,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store.db.simple_update( + self.store.db_pool.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -252,7 +252,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -263,14 +263,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... 
and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # We should now get the correct result again @@ -293,10 +293,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -315,7 +315,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -341,7 +341,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 3aeec0dc0f..d4c3b867e3 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) for i in range(0, 20): - self.get_success(self.store.db.runInteraction("insert", insert_event, i)) + self.get_success( + self.store.db_pool.runInteraction("insert", insert_event, i) + ) # this should get the last ten r = self.get_success(self.store.get_prev_events_for_room(room_id)) @@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for i in range(0, 20): self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room1) + self.store.db_pool.runInteraction("insert", insert_event, i, room1) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room2) + self.store.db_pool.runInteraction("insert", insert_event, i, room2) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room3) + self.store.db_pool.runInteraction("insert", insert_event, i, room3) ) # Test simple case @@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): depth = depth_map[event_id] - self.store.db.simple_insert_txn( + self.store.db_pool.simple_insert_txn( txn, table="events", values={ @@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) - self.store.db.simple_insert_many_txn( + self.store.db_pool.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for event_id in auth_graph: next_stream_ordering += 1 self.get_success( - 
self.store.db.runInteraction( + self.store.db_pool.runInteraction( "insert", insert_event, event_id, next_stream_ordering ) ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 2b1580feeb..857db071d4 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -60,7 +60,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.db.runInteraction( + counts = yield self.store.db_pool.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -81,7 +81,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.event_id, {user_id: action} ) ) - yield self.store.db.runInteraction( + yield self.store.db_pool.runInteraction( "", self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -89,12 +89,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return self.store.db.runInteraction( + return self.store.db_pool.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.db.runInteraction( + return self.store.db_pool.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -123,7 +123,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db.simple_delete( + yield self.store.db_pool.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -142,7 +142,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db.simple_insert( + return self.store.db_pool.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 55e9ecf264..e845410dae 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -14,7 +14,7 @@ # limitations under the License. 
-from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.unittest import HomeserverTestCase @@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db = self.store.db # type: Database + self.db_pool = self.store.db_pool # type: DatabasePool - self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) def _setup_db(self, txn): txn.execute("CREATE SEQUENCE foobar_seq") @@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def _create(conn): return MultiWriterIdGenerator( conn, - self.db, + self.db_pool, instance_name=instance_name, table="foobar", instance_column="instance_name", @@ -55,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): sequence_name="foobar_seq", ) - return self.get_success(self.db.runWithConnection(_create)) + return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): def _insert(txn): @@ -65,7 +65,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): (instance_name,), ) - self.get_success(self.db.runInteraction("test_single_instance", _insert)) + self.get_success(self.db_pool.runInteraction("test_single_instance", _insert)) def test_empty(self): """Test an ID generator against an empty database gives sensible @@ -178,7 +178,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token("master"), 7) - self.get_success(self.db.runInteraction("test", _get_next_txn)) + self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9c04e92577..259f2215f1 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -78,7 +78,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX why are we doing this here? this function is only run at startup # so it is odd to re-run it here. 
self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) ) @@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -280,7 +280,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ] self.hs.config.mau_limits_reserved_threepids = threepids - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0f0e1cd09b..41511d479f 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -343,7 +343,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -361,7 +361,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index f282921538..17c9da4838 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -179,10 +179,10 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now let's create a room, which will insert a membership @@ -192,7 +192,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -203,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... 
and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) diff --git a/tests/unittest.py b/tests/unittest.py index 68d2586efd..2152c693f2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -422,8 +422,8 @@ class HomeserverTestCase(TestCase): async def run_bg_updates(): with LoggingContext("run_bg_updates", request="run_bg_updates-1"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -571,7 +571,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db.simple_insert( + self.hs.get_datastore().db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", diff --git a/tox.ini b/tox.ini index 2b1db0f7f7..9a052c1e33 100644 --- a/tox.ini +++ b/tox.ini @@ -204,7 +204,7 @@ commands = mypy \ synapse/rest \ synapse/server_notices \ synapse/spam_checker_api \ - synapse/storage/data_stores/main/ui_auth.py \ + synapse/storage/databases/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ synapse/storage/state.py \ -- cgit 1.5.1 From 66f24449dd614b23ea4c572d8d613efeb129e4a2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Aug 2020 08:09:55 -0400 Subject: Improve performance of the register endpoint (#8009) --- changelog.d/8009.misc | 1 + synapse/api/errors.py | 4 +- synapse/handlers/auth.py | 19 +++-- synapse/rest/client/v2_alpha/account.py | 86 +++++++++++++++------- synapse/rest/client/v2_alpha/register.py | 108 ++++++++++++++++++---------- tests/rest/client/v2_alpha/test_register.py | 2 +- 6 files changed, 146 insertions(+), 74 deletions(-) create mode 100644 changelog.d/8009.misc (limited to 'tests') diff --git a/changelog.d/8009.misc b/changelog.d/8009.misc new file mode 100644 index 0000000000..3d58a11313 --- /dev/null +++ b/changelog.d/8009.misc @@ -0,0 +1 @@ +Improve the performance of the register endpoint. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b3bab1aa52..6e40630ab6 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -238,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception): (This indicates we should return a 401 with 'result' as the body) Attributes: + session_id: The ID of the ongoing interactive auth session. 
result: the server response to the request, which should be passed back to the client """ - def __init__(self, result: "JsonDict"): + def __init__(self, session_id: str, result: "JsonDict"): super(InteractiveAuthIncompleteError, self).__init__( "Interactive auth not yet complete" ) + self.session_id = session_id self.result = result diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c7d921c21a..c24e7bafe0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -162,7 +162,7 @@ class AuthHandler(BaseHandler): request_body: Dict[str, Any], clientip: str, description: str, - ) -> dict: + ) -> Tuple[dict, str]: """ Checks that the user is who they claim to be, via a UI auth. @@ -183,9 +183,14 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - The parameters for this request (which may + A tuple of (params, session_id). + + 'params' contains the parameters for this request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + Raises: InteractiveAuthIncompleteError if the client has not yet completed any of the permitted login flows @@ -207,7 +212,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = await self.check_auth( + result, params, session_id = await self.check_ui_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -230,7 +235,7 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") - return params + return params, session_id def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -240,7 +245,7 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - async def check_auth( + async def check_ui_auth( self, flows: List[List[str]], request: SynapseRequest, @@ -363,7 +368,7 @@ class AuthHandler(BaseHandler): if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session.session_id) + session.session_id, self._auth_dict_for_flows(flows, session.session_id) ) # check auth type currently being presented @@ -410,7 +415,7 @@ class AuthHandler(BaseHandler): ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError(ret) + raise InteractiveAuthIncompleteError(session.session_id, ret) async def add_oob_auth( self, stagetype: str, authdict: Dict[str, Any], clientip: str diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3767a809a4..fead85074b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,7 +18,12 @@ import logging from http import HTTPStatus from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( @@ -239,18 +244,12 @@ class PasswordRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the new password provided to us. 
- if "new_password" in body: - new_password = body.pop("new_password") + new_password = body.pop("new_password", None) + if new_password is not None: if not isinstance(new_password, str) or len(new_password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(new_password) - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "new_password_hash" in body: - raise SynapseError(400, "Unexpected property: new_password_hash") - body["new_password_hash"] = await self.auth_handler.hash(new_password) - # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -263,23 +262,49 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - params = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + params, session_id = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise user_id = requester.user.to_string() else: requester = None - result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] @@ -304,12 +329,21 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password_hash"]) - new_password_hash = params["new_password_hash"] + # If we have a password in this request, prefer it. Otherwise, there + # must be a password hash from an earlier request. 
+ if new_password: + password_hash = await self.auth_handler.hash(new_password) + else: + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + logout_devices = params.get("logout_devices", True) await self._set_password_handler.set_password( - user_id, new_password_hash, logout_devices, requester + user_id, password_hash, logout_devices, requester ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 370742ce59..a4c079196d 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -24,6 +24,7 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, + InteractiveAuthIncompleteError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -387,6 +388,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -412,20 +414,8 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the username/password provided to us. - if "password" in body: - password = body.pop("password") - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "password_hash" in body: - raise SynapseError(400, "Unexpected property: password_hash") - body["password_hash"] = await self.auth_handler.hash(password) - + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. desired_username = None if "username" in body: if not isinstance(body["username"], str) or len(body["username"]) > 512: @@ -459,22 +449,35 @@ class RegisterRestServlet(RestServlet): ) return 200, result # we throw for non 200 responses - # for regular registration, downcase the provided username before - # attempting to register it. This should mean - # that people who try to register with upper-case in their usernames - # don't get a nasty surprise. (Note that we treat username - # case-insenstively in login, so they are free to carry on imagining - # that their username is CrAzYh4cKeR if that keeps them happy) - if desired_username is not None: - desired_username = desired_username.lower() - # == Normal User Registration == (everyone else) - if not self.hs.config.enable_registration: + if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled") + # For regular registration, convert the provided username to lowercase + # before attempting to register it. This should mean that people who try + # to register with upper-case in their usernames don't get a nasty surprise. + # + # Note that we treat usernames case-insensitively in login, so they are + # free to carry on imagining that their username is CrAzYh4cKeR if that + # keeps them happy. 
+ if desired_username is not None: + desired_username = desired_username.lower() + + # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password_hash" not in body: + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + password = body.pop("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out @@ -484,6 +487,7 @@ class RegisterRestServlet(RestServlet): session_id = self.auth_handler.get_session_id(body) registered_user_id = None + password_hash = None if session_id: # if we get a registered user id out of here, it means we previously # registered a user for this session, so we could just return the @@ -492,7 +496,12 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + # Ensure that the username is valid. if desired_username is not None: await self.registration_handler.check_username( desired_username, @@ -500,20 +509,38 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) - auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", - ) + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. + try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise # Check that we're not trying to register a denied 3pid. # # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. 
- if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: @@ -535,12 +562,15 @@ class RegisterRestServlet(RestServlet): # don't re-register the threepids registered = False else: - # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password_hash"]) + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -582,7 +612,7 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password_hash=new_password_hash, + password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, @@ -595,8 +625,8 @@ class RegisterRestServlet(RestServlet): ): await self.store.upsert_monthly_active_user(registered_user_id) - # remember that we've now registered that user account, and with - # what user ID (since the user may not have specified) + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 7deaf5b24a..53a43038f0 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) + @override_config({"enable_registration": False}) def test_POST_disabled_registration(self): - self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) -- cgit 1.5.1 From d4a7829b12197faf52eb487c443ee09acafeb37e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Aug 2020 08:30:06 -0400 Subject: Convert synapse.api to async/await (#8031) --- changelog.d/8031.misc | 1 + synapse/api/auth.py | 123 ++++++++++----------- synapse/api/auth_blocking.py | 13 +-- synapse/api/filtering.py | 7 +- synapse/events/builder.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/module_api/__init__.py | 8 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/rest/client/v1/directory.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/storage/databases/main/client_ips.py | 5 +- tests/api/test_auth.py | 69 +++++++----- tests/api/test_filtering.py | 36 ++++-- tests/handlers/test_typing.py | 4 +- tests/rest/admin/test_user.py | 10 +- tests/rest/client/v1/test_profile.py | 4 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/rest/client/v1/test_typing.py | 6 +- .../test_resource_limits_server_notices.py | 2 +- tests/unittest.py | 24 ++-- 22 files changed, 172 insertions(+), 160 
deletions(-) create mode 100644 changelog.d/8031.misc (limited to 'tests') diff --git a/changelog.d/8031.misc b/changelog.d/8031.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8031.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2178e623da..d8190f92ab 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import List, Optional, Tuple import pymacaroons from netaddr import IPAddress -from twisted.internet import defer from twisted.web.server import Request import synapse.types @@ -80,13 +79,14 @@ class Auth(object): self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key - @defer.inlineCallbacks - def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) - auth_events_ids = yield self.compute_auth_events( + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ): + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -94,14 +94,13 @@ class Auth(object): room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) - @defer.inlineCallbacks - def check_user_in_room( + async def check_user_in_room( self, room_id: str, user_id: str, current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ): + ) -> EventBase: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -119,37 +118,35 @@ class Auth(object): Raises: AuthError if the user is/was not in the room. Returns: - Deferred[Optional[EventBase]]: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield defer.ensureDeferred( - self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) + member = await self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - membership = member.membership if member else None - if membership == Membership.JOIN: - return member + if member: + membership = member.membership - # XXX this looks totally bogus. Why do we not allow users who have been banned, - # or those who were members previously and have been re-invited? - if allow_departed_users and membership == Membership.LEAVE: - forgot = yield self.store.did_forget(user_id, room_id) - if not forgot: + if membership == Membership.JOIN: return member + # XXX this looks totally bogus. 
Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: + forgot = await self.store.did_forget(user_id, room_id) + if not forgot: + return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - @defer.inlineCallbacks - def check_host_in_room(self, room_id, host): + async def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.is_host_joined(room_id, host) + latest_event_ids = await self.store.is_host_joined(room_id, host) return latest_event_ids def can_federate(self, event, auth_events): @@ -160,14 +157,13 @@ class Auth(object): def get_public_keys(self, invite_event): return event_auth.get_public_keys(invite_event) - @defer.inlineCallbacks - def get_user_by_req( + async def get_user_by_req( self, request: Request, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, - ): + ) -> synapse.types.Requester: """ Get a registered user's ID. Args: @@ -180,7 +176,7 @@ class Auth(object): /login will deliver access tokens regardless of expiration. Returns: - defer.Deferred: resolves to a `synapse.types.Requester` object + Resolves to the requester Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -194,14 +190,14 @@ class Auth(object): access_token = self.get_access_token_from_request(request) - user_id, app_service = yield self._get_appservice_user_id(request) + user_id, app_service = await self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) if ip_addr and self._track_appservice_user_ips: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user_id, access_token=access_token, ip=ip_addr, @@ -211,7 +207,7 @@ class Auth(object): return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token( + user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) user = user_info["user"] @@ -221,7 +217,7 @@ class Auth(object): # Deny the request if the user account has expired. 
if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) if ( expiration_ts is not None and self.clock.time_msec() >= expiration_ts @@ -235,7 +231,7 @@ class Auth(object): device_id = user_info.get("device_id") if user and access_token and ip_addr: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, @@ -261,8 +257,7 @@ class Auth(object): except KeyError: raise MissingClientTokenError() - @defer.inlineCallbacks - def _get_appservice_user_id(self, request): + async def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) @@ -283,14 +278,13 @@ class Auth(object): if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") - if not (yield self.store.get_user_by_id(user_id)): + if not (await self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") return user_id, app_service - @defer.inlineCallbacks - def get_user_by_access_token( + async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ): + ) -> dict: """ Validate access token and get user_id from it Args: @@ -300,7 +294,7 @@ class Auth(object): allow_expired: If False, raises an InvalidClientTokenError if the token is expired Returns: - Deferred[dict]: dict that includes: + dict that includes: `user` (UserID) `is_guest` (bool) `token_id` (int|None): access token id. May be None if guest @@ -314,7 +308,7 @@ class Auth(object): if rights == "access": # first look in the database - r = yield self._look_up_user_by_access_token(token) + r = await self._look_up_user_by_access_token(token) if r: valid_until_ms = r["valid_until_ms"] if ( @@ -352,7 +346,7 @@ class Auth(object): # It would of course be much easier to store guest access # tokens in the database as well, but that would break existing # guest tokens. - stored_user = yield self.store.get_user_by_id(user_id) + stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: @@ -482,9 +476,8 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - @defer.inlineCallbacks - def _look_up_user_by_access_token(self, token): - ret = yield self.store.get_user_by_access_token(token) + async def _look_up_user_by_access_token(self, token): + ret = await self.store.get_user_by_access_token(token) if not ret: return None @@ -507,7 +500,7 @@ class Auth(object): logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender - return defer.succeed(service) + return service async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. @@ -522,7 +515,7 @@ class Auth(object): def compute_auth_events( self, event, current_state_ids: StateMap[str], for_verification: bool = False, - ): + ) -> List[str]: """Given an event and current state return the list of event IDs used to auth an event. @@ -530,11 +523,11 @@ class Auth(object): should be added to the event's `auth_events`. Returns: - defer.Deferred(list[str]): List of event IDs. + List of event IDs. 
""" if event.type == EventTypes.Create: - return defer.succeed([]) + return [] # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding @@ -553,7 +546,7 @@ class Auth(object): if auth_ev_id: auth_ids.append(auth_ev_id) - return defer.succeed(auth_ids) + return auth_ids async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the @@ -636,10 +629,9 @@ class Auth(object): return query_params[0].decode("ascii") - @defer.inlineCallbacks - def check_user_in_room_or_world_readable( + async def check_user_in_room_or_world_readable( self, room_id: str, user_id: str, allow_departed_users: bool = False - ): + ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. @@ -650,10 +642,9 @@ class Auth(object): members but have now departed Returns: - Deferred[tuple[str, str|None]]: Resolves to the current membership of - the user in the room and the membership event ID of the user. If - the user is not in the room and never has been, then - `(Membership.JOIN, None)` is returned. + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. """ try: @@ -662,15 +653,13 @@ class Auth(object): # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_in_room( + member_event = await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield defer.ensureDeferred( - self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) + visibility = await self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" ) if ( visibility diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 5c499b6b4e..49093bf181 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved @@ -36,8 +34,7 @@ class AuthBlocking(object): self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag @@ -60,7 +57,7 @@ class AuthBlocking(object): if user_id is not None: if user_id == self._server_notices_mxid: return - if (yield self.store.is_support_user(user_id)): + if await self.store.is_support_user(user_id): return if self._hs_disabled: @@ -76,11 +73,11 @@ class AuthBlocking(object): # If the user is already part of the MAU cohort or a trial user if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) + timestamp = await self.store.user_last_seen_monthly_active(user_id) if timestamp: return - 
is_trial = yield self.store.is_trial_user(user_id) + is_trial = await self.store.is_trial_user(user_id) if is_trial: return elif threepid: @@ -93,7 +90,7 @@ class AuthBlocking(object): # allow registration. Support users are excluded from MAU checks. return # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() + current_mau = await self.store.get_monthly_active_count() if current_mau >= self._max_mau_value: raise ResourceLimitError( 403, diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f988f62a1e..7393d6cb74 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -21,8 +21,6 @@ import jsonschema from canonicaljson import json from jsonschema import FormatChecker -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState @@ -137,9 +135,8 @@ class Filtering(object): super(Filtering, self).__init__() self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_user_filter(self, user_localpart, filter_id): - result = yield self.store.get_user_filter(user_localpart, filter_id) + async def get_user_filter(self, user_localpart, filter_id): + result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) def add_user_filter(self, user_localpart, user_filter): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 69b53ca2bc..4e179d49b3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,7 +106,7 @@ class EventBuilder(object): state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_ids = await self._auth.compute_auth_events(self, state_ids) + auth_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b3764dedae..593932adb7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler): if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 43901d0934..708533d4d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1061,7 +1061,7 @@ class EventCreationHandler(object): raise SynapseError(400, "Cannot redact event from a different room") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8201849951..c2fb757d9a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,12 +194,16 @@ class ModuleApi(object): synapse.api.errors.AuthError: the access token is invalid """ # see if the access token corresponds to a device - user_info = yield self._auth.get_user_by_access_token(access_token) + user_info = yield defer.ensureDeferred( + 
self._auth.get_user_by_access_token(access_token) + ) device_id = user_info.get("device_id") user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self._hs.get_device_handler().delete_device(user_id, device_id) + yield defer.ensureDeferred( + self._hs.get_device_handler().delete_device(user_id, device_id) + ) else: # no associated device. Just delete the access token. yield defer.ensureDeferred( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 04b9d8ac82..e7fcee0e87 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object): pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 60dd3f6701..a6fdedde63 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore): name="client_ip_last_seen", keylen=4, max_entries=50000 ) - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 5934b1fe8b..b210015173 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler try: - service = await self.auth.get_appservice_by_req(request) + service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) await dir_handler.delete_appservice_association(service, room_alias) logger.info( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a4c079196d..c549c090b3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = await self.auth.get_appservice_by_req(request) + appservice = self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 712c8d0264..50d71f5ebc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks - def insert_client_ip( + async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): if not now: @@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield 
self.populate_monthly_active_users(user_id) + await self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 0bfb86bf1f..5d45689c8c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) + self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.is_support_user = Mock(return_value=defer.succeed(False)) @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" @@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] 
= [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") @@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + # This just needs to return a truth-y value. 
+ self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.failureResultOf(d, AuthError) @defer.inlineCallbacks def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org", "device_id": "device"} + return_value=defer.succeed( + {"name": "@baldrick:matrix.org", "device_id": "device"} + ) ) user_id = "@baldrick:matrix.org" @@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase): def get_user(tok): if token != tok: - return None - return { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + return defer.succeed(None) + return defer.succeed( + { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={"is_guest": False}) + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 4e67503cf0..1fab1d6b69 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart + "2", filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield 
self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events=events) @@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events) @@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals( user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 + yield defer.ensureDeferred( + self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) ) ), ) @@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase): user_localpart=user_localpart, user_filter=user_filter_json ) - filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5878f74175..b7d0adb10e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - return defer.succeed(None) + return None hs.get_auth().check_user_in_room = check_user_in_room diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f16eef15f7..17d0aae2e9 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -20,6 +20,8 @@ import urllib.parse from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() # Set monthly active users to the limit - store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit self.get_failure( @@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the 
blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 8df58b4a63..ace0a3c08d 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase): profile_handler=self.mock_handler, ) - def _get_user_by_req(request=None, allow_guest=False): - return defer.succeed(synapse.types.create_requester(myid)) + async def _get_user_by_req(request=None, allow_guest=False): + return synapse.types.create_requester(myid) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5ccda8b2bd..ef6b775ed2 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -23,8 +23,6 @@ from urllib import parse as urlparse from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus @@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None self.hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 18260bb90e..94d2bf2eb1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_handlers().federation_handler = Mock() - def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, @@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f70353b0d..3f88abe3d2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=1000) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) diff --git a/tests/unittest.py b/tests/unittest.py index 2152c693f2..d0bba3ddef 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: - def get_user_by_access_token(token=None, allow_guest=False): - return succeed( - { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - ) - - def get_user_by_req(request, allow_guest=False, 
rights="access"): - return succeed( - create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + async def get_user_by_req(request, allow_guest=False, rights="access"): + return create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None ) self.hs.get_auth().get_user_by_req = get_user_by_req -- cgit 1.5.1 From fe6cfc80ec6ed3b9e29ca74cde5dcfae3d8236ea Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Aug 2020 08:39:35 -0400 Subject: Convert some util functions to async (#8035) --- changelog.d/8035.misc | 1 + synapse/util/metrics.py | 39 ++++++++++++++++++++------------------ synapse/util/retryutils.py | 16 ++++++---------- tests/util/test_retryutils.py | 44 +++++++++++-------------------------------- 4 files changed, 39 insertions(+), 61 deletions(-) create mode 100644 changelog.d/8035.misc (limited to 'tests') diff --git a/changelog.d/8035.misc b/changelog.d/8035.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8035.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index ec61e14423..a805f51df1 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from functools import wraps from prometheus_client import Counter -from twisted.internet import defer - from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge @@ -62,25 +59,31 @@ in_flight = InFlightGauge( def measure_func(name=None): - def wrapper(func): - block_name = func.__name__ if name is None else name + """ + Used to decorate an async function with a `Measure` context manager. + + Usage: - if inspect.iscoroutinefunction(func): + @measure_func() + async def foo(...): + ... - @wraps(func) - async def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = await func(self, *args, **kwargs) - return r + Which is analogous to: - else: + async def foo(...): + with Measure(...): + ... 
+ + """ + + def wrapper(func): + block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r return measured_func diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 8794317caa..919988d3bc 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -15,8 +15,6 @@ import logging import random -from twisted.internet import defer - import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -54,8 +52,7 @@ class NotRetryingDestination(Exception): self.destination = destination -@defer.inlineCallbacks -def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) Example usage: try: - limiter = yield get_retry_limiter(destination, clock, store) + limiter = await get_retry_limiter(destination, clock, store) with limiter: - response = yield do_request() + response = await do_request() except NotRetryingDestination: # We aren't ready to retry that destination. raise @@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) failure_ts = None retry_last_ts, retry_interval = (0, 0) - retry_timings = yield store.get_destination_retry_timings(destination) + retry_timings = await store.get_destination_retry_timings(destination) if retry_timings: failure_ts = retry_timings["failure_ts"] @@ -222,10 +219,9 @@ class RetryDestinationLimiter(object): if self.failure_ts is None: self.failure_ts = retry_last_ts - @defer.inlineCallbacks - def store_retry_timings(): + async def store_retry_timings(): try: - yield self.store.set_destination_retry_timings( + await self.store.set_destination_retry_timings( self.destination, self.failure_ts, retry_last_ts, diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e348694ad..bc42ffce88 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request self.pump(1) @@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase): with limiter: pass - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) def test_limiter(self): """General test case which walks through the process of a failing request""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = 
self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], failure_ts) self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) # now if we try again we should get a failure - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - self.failureResultOf(d, NotRetryingDestination) + self.get_failure( + get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + ) # # advance the clock and try again # self.pump(MIN_RETRY_INTERVAL) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], retry_ts) self.assertGreaterEqual( @@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase): # one more go, with success # self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) with limiter: @@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase): # wait for the update to land self.pump() - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) -- cgit 1.5.1 From 2ffd6783c7af12e3c29e1a44dee4a9deeb83890b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 6 Aug 2020 17:15:35 +0100 Subject: Revert #7736 (#8039) --- changelog.d/7736.feature | 1 - changelog.d/8039.misc | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 - synapse/push/push_tools.py | 17 ++- synapse/rest/client/v2_alpha/sync.py | 1 - synapse/storage/databases/main/cache.py | 1 - synapse/storage/databases/main/events.py | 48 +------ synapse/storage/databases/main/events_worker.py | 86 +---------- .../main/schema/delta/58/12unread_messages.sql | 18 --- tests/rest/client/v1/utils.py | 20 --- tests/rest/client/v2_alpha/test_sync.py | 157 +-------------------- 12 files changed, 19 insertions(+), 339 deletions(-) delete mode 100644 changelog.d/7736.feature create mode 100644 changelog.d/8039.misc delete mode 100644 synapse/storage/databases/main/schema/delta/58/12unread_messages.sql (limited to 'tests') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature deleted file mode 100644 index feb02be234..0000000000 --- a/changelog.d/7736.feature +++ /dev/null @@ -1 +0,0 @@ -Add unread messages count to sync responses, as specified in 
[MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). diff --git a/changelog.d/8039.misc b/changelog.d/8039.misc new file mode 100644 index 0000000000..599933c80e --- /dev/null +++ b/changelog.d/8039.misc @@ -0,0 +1 @@ +Revert MSC2654 implementation because of perf issues. Please delete this line when processing the 1.19 changelog. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index ae5e1810fc..a34bdf1830 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -67,7 +67,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url", "count_as_unread"], + "events": ["processed", "outlier", "contains_url"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5a19bac929..c42dac18f5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,7 +103,6 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1887,10 +1886,6 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] - - unread_count = await self.store.get_unread_message_count_for_user( - room_id, sync_config.user.to_string(), - ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1899,7 +1894,6 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, - unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index bc8f71916b..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,13 +21,22 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + badge = len(invites) for room_id in joins: - unread_count = await store.get_unread_message_count_for_user(room_id, user_id) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if unread_count else 0 + if room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[room_id] + + notifs = await ( + store.get_unread_event_push_actions_by_room_for_user( + room_id, user_id, last_unread_event_id + ) + ) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 3f5bf75e59..a5c24fbd63 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,7 +426,6 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 683afde52b..10de446065 100644 --- a/synapse/storage/databases/main/cache.py +++ 
b/synapse/storage/databases/main/cache.py @@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4d8a24ce4b..1a68bf32cb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -53,47 +53,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - def encode_json(json_object): """ @@ -239,10 +198,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -864,9 +819,8 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), - "count_as_unread": should_count_as_unread(event, context), } - for event, context in events_and_contexts + for event, _ in events_and_contexts ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7b7393f6e..755b7a2a85 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: 
str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db_pool.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? - AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c73..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. 
-ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 51941f99f9..8933b560d2 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -165,26 +165,6 @@ class RestHelper(object): return channel.json_body - def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") - ) - render(request, self.resource, self.hs.get_reactor()) - - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) - ) - - return channel.json_body - def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import read_marker, sync +from synapse.rest.client.v2_alpha import sync from tests import unittest from tests.server import TimedOutException @@ -324,156 +324,3 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. 
- self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: True): - """Syncs and compares the unread count with the expected value.""" - - request, channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, - ) - self.render(request) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, - ) - - # Store the next batch for the next request. 
- self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1 From 7620912d84f6a8b24143f1340dd653f44b13bf30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2020 14:21:24 +0100 Subject: Add health check endpoint (#8048) --- changelog.d/8048.feature | 1 + docs/reverse_proxy.md | 7 +++++++ synapse/app/generic_worker.py | 6 +++++- synapse/app/homeserver.py | 5 ++++- synapse/http/site.py | 9 ++++++++- synapse/rest/health.py | 31 +++++++++++++++++++++++++++++++ tests/rest/test_health.py | 34 ++++++++++++++++++++++++++++++++++ 7 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8048.feature create mode 100644 synapse/rest/health.py create mode 100644 tests/rest/test_health.py (limited to 'tests') diff --git a/changelog.d/8048.feature b/changelog.d/8048.feature new file mode 100644 index 0000000000..8521d1920e --- /dev/null +++ b/changelog.d/8048.feature @@ -0,0 +1 @@ +Add a `/health` endpoint to every configured HTTP listener that can be used as a health check endpoint by load balancers. diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 7bfb96eff6..fd48ba0874 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -139,3 +139,10 @@ client IP addresses are recorded correctly. Having done so, you can then use `https://matrix.example.com` (instead of `https://matrix.example.com:8448`) as the "Custom server" when connecting to Synapse from a client. + + +## Health check endpoint + +Synapse exposes a health check endpoint for use by reverse proxies. +Each configured HTTP listener has a `/health` endpoint which always returns +200 OK (and doesn't get logged). diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1a16d0b9f8..7957586d69 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -123,6 +123,7 @@ from synapse.rest.client.v2_alpha.account_data import ( from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer from synapse.storage.databases.main.censor_events import CensorEventsStore @@ -493,7 +494,10 @@ class GenericWorkerServer(HomeServer): site_tag = listener_config.http_options.tag if site_tag is None: site_tag = port - resources = {} + + # We always include a health resource. + resources = {"/health": HealthResource()} + for res in listener_config.http_options.resources: for name in res.names: if name == "metrics": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index d87a77718e..98d0d14a12 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -68,6 +68,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource +from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer @@ -98,7 +99,9 @@ class SynapseHomeServer(HomeServer): if site_tag is None: site_tag = port - resources = {} + # We always include a health resource. 
+ resources = {"/health": HealthResource()} + for res in listener_config.http_options.resources: for name in res.names: if name == "openid" and "federation" in res.names: diff --git a/synapse/http/site.py b/synapse/http/site.py index f506152fea..79a9229a26 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -286,7 +286,9 @@ class SynapseRequest(Request): # the connection dropped) code += "!" - self.site.access_logger.info( + log_level = logging.INFO if self._should_log_request() else logging.DEBUG + self.site.access_logger.log( + log_level, "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" ' %sB %s "%s %s %s" "%s" [%d dbevts]', @@ -314,6 +316,11 @@ class SynapseRequest(Request): except Exception as e: logger.warning("Failed to stop metrics: %r", e) + def _should_log_request(self) -> bool: + """Whether we should log at INFO that we processed the request. + """ + return self.path != b"/health" + class XForwardedForRequest(SynapseRequest): def __init__(self, *args, **kw): diff --git a/synapse/rest/health.py b/synapse/rest/health.py new file mode 100644 index 0000000000..0170950bf3 --- /dev/null +++ b/synapse/rest/health.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource + + +class HealthResource(Resource): + """A resource that does nothing except return a 200 with a body of `OK`, + which can be used as a health check. + + Note: `SynapseRequest._should_log_request` ensures that requests to + `/health` do not get logged at INFO. + """ + + isLeaf = 1 + + def render_GET(self, request): + request.setHeader(b"Content-Type", b"text/plain") + return b"OK" diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py new file mode 100644 index 0000000000..2d021f6565 --- /dev/null +++ b/tests/rest/test_health.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.rest.health import HealthResource + +from tests import unittest + + +class HealthCheckTests(unittest.HomeserverTestCase): + def setUp(self): + super().setUp() + + # replace the JsonResource with a HealthResource. 
+ self.resource = HealthResource() + + def test_health(self): + request, channel = self.make_request("GET", "/health", shorthand=False) + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.result["body"], b"OK") -- cgit 1.5.1 From f3fe6961b211d898aa347771df598c531fbca90c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Aug 2020 12:17:17 -0400 Subject: Convert additional database stores to async/await (#8045) --- changelog.d/8045.misc | 1 + synapse/storage/databases/main/client_ips.py | 54 +++++----- synapse/storage/databases/main/search.py | 69 ++++++------- synapse/storage/databases/main/signatures.py | 7 +- synapse/storage/databases/main/user_directory.py | 124 ++++++++--------------- tests/storage/test_user_directory.py | 4 +- 6 files changed, 107 insertions(+), 152 deletions(-) create mode 100644 changelog.d/8045.misc (limited to 'tests') diff --git a/changelog.d/8045.misc b/changelog.d/8045.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8045.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 50d71f5ebc..216a5925fc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -82,21 +81,19 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): "devices_last_seen", self._devices_last_seen_update ) - @defer.inlineCallbacks - def _remove_user_ip_nonunique(self, progress, batch_size): + async def _remove_user_ip_nonunique(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.db_pool.runWithConnection(f) - yield self.db_pool.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( "user_ips_drop_nonunique_index" ) return 1 - @defer.inlineCallbacks - def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress, batch_size): # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. @@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) + await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) - yield self.db_pool.updates._end_background_update("user_ips_analyze") + await self.db_pool.updates._end_background_update("user_ips_analyze") return 1 - @defer.inlineCallbacks - def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress, batch_size): # This works function works by scanning the user_ips table in batches # based on `last_seen`. 
For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they @@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.db_pool.runInteraction( + end_last_seen = await self.db_pool.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -275,15 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.db_pool.runInteraction("user_ips_dups_remove", remove) + await self.db_pool.runInteraction("user_ips_dups_remove", remove) if last: - yield self.db_pool.updates._end_background_update("user_ips_remove_dupes") + await self.db_pool.updates._end_background_update("user_ips_remove_dupes") return batch_size - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update(self, progress, batch_size): """Background update to insert last seen info into devices table """ @@ -346,12 +341,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return len(rows) - updated = yield self.db_pool.runInteraction( + updated = await self.db_pool.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) if not updated: - yield self.db_pool.updates._end_background_update("devices_last_seen") + await self.db_pool.updates._end_background_update("devices_last_seen") return updated @@ -460,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Failed to upsert, log and continue logger.error("Failed to insert client IP %r: %r", entry, e) - @defer.inlineCallbacks - def get_last_client_ip_by_device(self, user_id, device_id): + async def get_last_client_ip_by_device( + self, user_id: str, device_id: Optional[str] + ) -> Dict[Tuple[str, str], dict]: """For each device_id listed, give the user_ip it was last seen on Args: - user_id (str) - device_id (str): If None fetches all devices for the user + user_id: The user to fetch devices for. + device_id: If None fetches all devices for the user Returns: - defer.Deferred: resolves to a dict, where the keys - are (user_id, device_id) tuples. The values are also dicts, with - keys giving the column names + A dictionary mapping a tuple of (user_id, device_id) to dicts, with + keys giving the column names from the devices table. 
""" keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id - res = yield self.db_pool.simple_select_list( + res = await self.db_pool.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -500,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } return ret - @defer.inlineCallbacks - def get_user_ip_and_agents(self, user): + async def get_user_ip_and_agents(self, user): user_id = user.to_string() results = {} @@ -511,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2162d0712d..7f8d1880e5 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -16,8 +16,7 @@ import logging import re from collections import namedtuple - -from twisted.internet import defer +from typing import List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -114,8 +113,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) - @defer.inlineCallbacks - def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -206,19 +204,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return len(event_search_rows) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search(self, progress, batch_size): """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. 
""" @@ -255,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.db_pool.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME ) return 1 - @defer.inlineCallbacks - def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -288,12 +284,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) conn.set_session(autocommit=False) - yield self.db_pool.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.db_pool.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, @@ -331,12 +327,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return len(rows), True - num_rows, finished = yield self.db_pool.runInteraction( + num_rows, finished = await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_ORDER_UPDATE_NAME ) @@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(SearchStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): + async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. Args: @@ -425,7 +420,7 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. 
sql += " ORDER BY rank DESC LIMIT 500" - results = yield self.db_pool.execute( + results = await self.db_pool.execute( "search_msgs", self.db_pool.cursor_to_dict, sql, *args ) @@ -433,7 +428,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -442,11 +437,11 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db_pool.execute( + count_results = await self.db_pool.execute( "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) @@ -462,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore): "count": count, } - @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + async def search_rooms( + self, + room_ids: List[str], + search_term: str, + keys: List[str], + limit, + pagination_token: Optional[str] = None, + ) -> List[dict]: """Performs a full text search over events with given keys. Args: - room_id (list): The room_ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - pagination_token (str): A pagination token previously returned + room_ids: The room_ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", + "content.name", "content.topic" + pagination_token: A pagination token previously returned Returns: - list of dicts + Each match as a dictionary. 
""" clauses = [] @@ -577,7 +578,7 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self.db_pool.execute( + results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args ) @@ -585,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -594,11 +595,11 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db_pool.execute( + count_results = await self.db_pool.execute( "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index dae8e8bd29..be191dd870 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -15,8 +15,6 @@ from unpaddedbase64 import encode_base64 -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList @@ -40,9 +38,8 @@ class SignatureWorkerStore(SQLBaseStore): return self.db_pool.runInteraction("get_event_reference_hashes", f) - @defer.inlineCallbacks - def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes(event_ids) + async def add_event_hashes(self, event_ids): + hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index d73a8e8ab9..af21fe457a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state import StateFilter @@ -59,8 +57,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) - @defer.inlineCallbacks - def _populate_user_directory_createtables(self, progress, batch_size): + async def _populate_user_directory_createtables(self, progress, batch_size): # Get all the rooms that we want to process. 
def _make_staging_area(txn): @@ -102,45 +99,43 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.db_pool.runInteraction( + new_pos = await self.get_max_stream_id_in_current_state_deltas() + await self.db_pool.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self.db_pool.simple_insert( + await self.db_pool.simple_insert( TEMP_TABLE + "_position", {"position": new_pos} ) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_createtables" ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_cleanup(self, progress, batch_size): + async def _populate_user_directory_cleanup(self, progress, batch_size): """ Update the user directory stream position, then clean up the old tables. """ - position = yield self.db_pool.simple_select_one_onecol( + position = await self.db_pool.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) - yield self.update_user_directory_stream_pos(position) + await self.update_user_directory_stream_pos(position) def _delete_staging_area(txn): txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_cleanup" ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_process_rooms(self, progress, batch_size): + async def _populate_user_directory_process_rooms(self, progress, batch_size): """ Args: progress (dict) @@ -151,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # If we don't have progress filed, delete everything. if not progress: - yield self.delete_all_from_user_dir() + await self.delete_all_from_user_dir() def _get_next_batch(txn): # Only fetch 250 rooms, so we don't fetch too many at once, even @@ -176,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return rooms_to_work_on - rooms_to_work_on = yield self.db_pool.runInteraction( + rooms_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_rooms" ) return 1 @@ -195,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = yield self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) if is_in_room: - is_public = yield self.is_room_world_readable_or_publicly_joinable( + is_public = await self.is_room_world_readable_or_publicly_joinable( room_id ) - users_with_profile = yield defer.ensureDeferred( - state.get_current_users_in_room(room_id) - ) + users_with_profile = await state.get_current_users_in_room(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. 
for user_id, profile in users_with_profile.items(): - yield self.update_profile_in_user_dir( + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -223,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): to_insert.add(user_id) if to_insert: - yield self.add_users_in_public_rooms(room_id, to_insert) + await self.add_users_in_public_rooms(room_id, to_insert) to_insert.clear() else: for user_id in user_ids: @@ -243,22 +236,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # If it gets too big, stop and write to the database # to prevent storing too much in RAM. if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: - yield self.add_users_who_share_private_room( + await self.add_users_who_share_private_room( room_id, to_insert ) to_insert.clear() if to_insert: - yield self.add_users_who_share_private_room(room_id, to_insert) + await self.add_users_who_share_private_room(room_id, to_insert) to_insert.clear() # We've finished a room. Delete it from the table. - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( TEMP_TABLE + "_rooms", {"room_id": room_id} ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_rooms", @@ -273,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return processed_event_count - @defer.inlineCallbacks - def _populate_user_directory_process_users(self, progress, batch_size): + async def _populate_user_directory_process_users(self, progress, batch_size): """ If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -305,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return users_to_work_on - users_to_work_on = yield self.db_pool.runInteraction( + users_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -322,18 +314,18 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) for user_id in users_to_work_on: - profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) - yield self.update_profile_in_user_dir( + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) # We've finished processing a user. Delete it from the table. - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( TEMP_TABLE + "_users", {"user_id": user_id} ) # Update the remaining counter. 
progress["remaining"] -= 1 - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_users", @@ -342,8 +334,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return len(users_to_work_on) - @defer.inlineCallbacks - def is_room_world_readable_or_publicly_joinable(self, room_id): + async def is_room_world_readable_or_publicly_joinable(self, room_id): """Check if the room is either world_readable or publically joinable """ @@ -353,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = yield self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) if hist_vis_ev: if hist_vis_ev.content.get("history_visibility") == "world_readable": return True @@ -590,19 +581,18 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "remove_from_user_dir", _remove_from_user_dir_txn ) - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): + async def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self.db_pool.simple_select_onecol( + user_ids_share_pub = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self.db_pool.simple_select_onecol( + user_ids_share_priv = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -645,8 +635,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "remove_user_who_share_room", _remove_user_who_share_room_txn ) - @defer.inlineCallbacks - def get_user_dir_rooms_user_is_in(self, user_id): + async def get_user_dir_rooms_user_is_in(self, user_id): """ Returns the rooms that a user is in. @@ -656,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self.db_pool.simple_select_onecol( + rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self.db_pool.simple_select_onecol( + pub_rows = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,32 +663,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - @defer.inlineCallbacks - def get_rooms_in_common_for_users(self, user_id, other_user_id): - """Given two user_ids find out the list of rooms they share. 
- """ - sql = """ - SELECT room_id FROM ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) AS f1 INNER JOIN ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) f2 USING (room_id) - """ - - rows = yield self.db_pool.execute( - "get_rooms_in_common_for_users", None, sql, user_id, other_user_id - ) - - return [room_id for room_id, in rows] - def get_user_directory_stream_pos(self): return self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", @@ -708,8 +671,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): desc="get_user_directory_stream_pos", ) - @defer.inlineCallbacks - def search_user_dir(self, user_id, search_term, limit): + async def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -806,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self.db_pool.execute( + results = await self.db_pool.execute( "search_user_dir", self.db_pool.cursor_to_dict, sql, *args ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 6a545d2eb0..ecfafe68a9 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): def test_search_user_dir(self): # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) self.assertDictEqual( @@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): def test_search_user_dir_all_users(self): self.hs.config.user_directory_search_all_users = True try: - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) self.assertDictEqual( -- cgit 1.5.1 From 7f837959ea25ef50b3675c9c2596ef42592dc127 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Aug 2020 13:36:29 -0400 Subject: Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_users database to async (#8042) --- changelog.d/8042.misc | 1 + synapse/storage/databases/main/devices.py | 12 ++-- synapse/storage/databases/main/directory.py | 51 ++++++++------- synapse/storage/databases/main/e2e_room_keys.py | 30 +++++---- synapse/storage/databases/main/end_to_end_keys.py | 73 +++++++++++----------- .../storage/databases/main/monthly_active_users.py | 31 ++++----- tests/handlers/test_appservice.py | 2 +- tests/storage/test_directory.py | 32 +++++++--- tests/storage/test_end_to_end_keys.py | 12 ++-- tests/storage/test_monthly_active_users.py | 17 +++-- 10 files changed, 141 insertions(+), 120 deletions(-) create mode 100644 changelog.d/8042.misc (limited to 'tests') diff --git a/changelog.d/8042.misc b/changelog.d/8042.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8042.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 81e64de126..7a5f0bab05 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + cross_signing_key = yield defer.ensureDeferred( + self.get_e2e_cross_signing_key(user, "master") + ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield self.get_e2e_cross_signing_key( - user, "self_signing" + cross_signing_key = yield defer.ensureDeferred( + self.get_e2e_cross_signing_key(user, "self_signing") ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( @@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore): destination (str): The host the device updates are intended for from_stream_id (int): The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded + user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: @@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 7819bfcbb3..037e02603c 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,30 +14,29 @@ # limitations under the License. from collections import namedtuple -from typing import Optional - -from twisted.internet import defer +from typing import Iterable, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: + """Gets the room_id and server list for a given room_alias Args: - room_alias (RoomAlias) + room_alias: The alias to translate to an ID. Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found + The room alias mapping or None if no association can be found. 
""" - room_id = yield self.db_pool.simple_select_one_onecol( + room_id = await self.db_pool.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self.db_pool.simple_select_onecol( + servers = await self.db_pool.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + async def create_room_alias_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Iterable[str], + creator: Optional[str] = None, + ) -> None: """ Creates an association between a room alias and room_id/servers Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. - - Returns: - Deferred + room_alias: The alias to create. + room_id: The target of the alias. + servers: A list of servers through which it may be possible to join the room + creator: Optional user_id of creator. """ def alias_txn(txn): @@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "create_room_alias_association", alias_txn ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - return ret - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.db_pool.runInteraction( + async def delete_room_alias(self, room_alias: RoomAlias) -> str: + room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias): + def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index c4aaec3993..2eeb9f97dc 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -23,8 +21,9 @@ from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + async def update_e2e_room_key( + self, user_id, version, room_id, session_id, room_key + ): """Replaces the encrypted E2E room key for a given session in a given backup Args: @@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="update_e2e_room_key", ) - @defer.inlineCallbacks - def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys(self, user_id, version, room_keys): """Bulk add room keys to a given backup. Args: @@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self.db_pool.simple_insert_many( + await self.db_pool.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. @@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred of the deletion transaction + The deletion transaction """ keyvalues = {"user_id": user_id, "version": int(version)} @@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 6126376a6f..f93e0d320d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,12 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection -from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter class EndToEndKeyWorkerStore(SQLBaseStore): @trace - @defer.inlineCallbacks - def get_e2e_device_keys( + async def get_e2e_device_keys( self, query_list, include_all_devices=False, include_deleted_devices=False ): """Fetch a list of device keys. @@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv(result) return result - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + async def get_e2e_one_time_keys( + self, user_id: str, device_id: str, key_ids: List[str] + ) -> Dict[Tuple[str, str], str]: """Retrieve a number of one-time keys for a user Args: @@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): retrieve Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key + A map from (algorithm, key_id) to json string for key """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) return result - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + async def add_e2e_one_time_keys( + self, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: """Insert some new one time keys for a device. Errors if any of the keys already exist. Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ def _add_e2e_one_time_keys(txn): @@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - @defer.inlineCallbacks - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + async def get_e2e_cross_signing_key( + self, user_id: str, key_type: str, from_user_id: Optional[str] = None + ) -> Optional[dict]: """Returns a user's cross-signing key. 
Args: - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' + user_id: the user whose key is being requested + key_type: the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on + from_user_id: if specified, signatures made by this user on the self-signing key will be included in the result Returns: dict of the key data or None if not found """ - res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) user_keys = res.get(user_id) if not user_keys: return None @@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return keys - @defer.inlineCallbacks - def get_e2e_cross_signing_keys_bulk( - self, user_ids: List[str], from_user_id: str = None - ) -> defer.Deferred: + async def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: Optional[str] = None + ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. Args: - user_ids (list[str]): the users whose keys are being requested - from_user_id (str): if specified, signatures made by this user on + user_ids: the users whose keys are being requested + from_user_id: if specified, signatures made by this user on the self-signing keys will be included in the result Returns: - Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to - key data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A map of user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, + or their user ID will map to None. """ - result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, result, diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 02b01d9619..e71cdd2cb4 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -15,8 +15,6 @@ import logging from typing import List -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached @@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server Args: - user_id (str): user to add/update - - Returns: - Deferred + user_id: user to add/update """ # Support user never to be included in MAU stats. 
Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # _initialise_reserved_users reasoning that it would be very strange to # include a support user in this context. - is_support = yield self.is_support_user(user_id) + is_support = await self.is_support_user(user_id) if is_support: return - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): return is_insert - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) + is_guest = await self.is_guest(user_id) if is_guest: return - is_trial = yield self.is_trial_user(user_id) + is_trial = await self.is_trial_user(user_id) if is_trial: return - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + last_seen_timestamp = await self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy @@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. 
if not self._limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) else: - count = yield self.get_monthly_active_count() + count = await self.get_monthly_active_count() if count < self._max_mau_value: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 628f7d8db0..2a0b7c1b56 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias.return_value = defer.succeed( + self.mock_store.get_association_from_room_alias.return_value = make_awaitable( Mock(room_id=room_id, servers=servers) ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 4e128e1047..daac947cb2 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_room_to_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertEquals( @@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_alias_to_room(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - (yield self.store.get_association_from_room_alias(self.alias)), + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ), ) @defer.inlineCallbacks def test_delete_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) - room_id = yield self.store.delete_room_alias(self.alias) + room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (yield self.store.get_association_from_room_alias(self.alias)) + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 398d546280..9f8d30373b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + 
self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) yield self.store.store_device("user", "device", "display_name") - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) - res = yield self.store.get_e2e_device_keys( - (("user1", "device1"), ("user2", "device2")) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 259f2215f1..e793781a26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.get_success(self.store.populate_monthly_active_users("@user:sever")) -- cgit 1.5.1 From fcbab08cbd46d28976411b1d014a4efb76c8b7a4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 10 Aug 2020 12:29:47 +0100 Subject: Add an assertion on prev_events in create_new_client_event (#8041) I think this would have caught 
all the cases in https://github.com/matrix-org/synapse/issues/7642 - and I think a 500 makes more sense here than a 403 --- changelog.d/8041.misc | 1 + synapse/handlers/message.py | 9 +++++++++ tests/storage/test_redaction.py | 4 ++++ 3 files changed, 14 insertions(+) create mode 100644 changelog.d/8041.misc (limited to 'tests') diff --git a/changelog.d/8041.misc b/changelog.d/8041.misc new file mode 100644 index 0000000000..eefa98d744 --- /dev/null +++ b/changelog.d/8041.misc @@ -0,0 +1 @@ +Add an assertion on prev_events in create_new_client_event. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 708533d4d1..8ddded8389 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -768,6 +768,15 @@ class EventCreationHandler(object): else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + # we now ought to have some prev_events (unless it's a create event). + # + # do a quick sanity check here, rather than waiting until we've created the + # event and then try to auth it (which fails with a somewhat confusing "No + # create event in auth events") + assert ( + builder.type == EventTypes.Create or len(prev_event_ids) > 0 + ), "Attempting to create an event with no prev_events" + event = await builder.build(prev_event_ids=prev_event_ids) context = await self.state.compute_event_context(event) if requester: diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 41511d479f..1ea35d60c1 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -251,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase): def room_id(self): return self._base_builder.room_id + @property + def type(self): + return self._base_builder.type + event_1, context_1 = self.get_success( self.event_creation_handler.create_new_client_event( EventIdManglingBuilder( -- cgit 1.5.1 From a0acdfa9e93ae63a3adee264d5420fdd1d38d76e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Aug 2020 17:21:13 -0400 Subject: Converts event_federation and registration databases to async/await (#8061) --- changelog.d/8061.misc | 1 + synapse/storage/databases/main/event_federation.py | 38 ++-- synapse/storage/databases/main/registration.py | 233 ++++++++++----------- synapse/storage/databases/state/bg_updates.py | 18 +- tests/handlers/test_register.py | 11 +- tests/storage/test_monthly_active_users.py | 8 +- tests/storage/test_registration.py | 18 +- 7 files changed, 150 insertions(+), 177 deletions(-) create mode 100644 changelog.d/8061.misc (limited to 'tests') diff --git a/changelog.d/8061.misc b/changelog.d/8061.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8061.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
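The test updates in this patch apply the caller-side counterpart of the same conversion: an inlineCallbacks-based test can only yield Deferreds, so each call into a now-async storage method is wrapped in defer.ensureDeferred. A minimal, self-contained sketch of that shape follows; some_async_store_method is a hypothetical stand-in for the converted storage methods, not real Synapse code.

from twisted.internet import defer
from twisted.trial import unittest


async def some_async_store_method(user_id):
    # Stand-in for a storage method converted to async/await in this patch.
    return {"name": user_id}


class EnsureDeferredExampleTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_wraps_coroutine_in_deferred(self):
        # Yielding the bare coroutine would not work here; ensureDeferred
        # adapts it into a Deferred that inlineCallbacks understands.
        result = yield defer.ensureDeferred(some_async_store_method("@user:test"))
        self.assertEqual(result["name"], "@user:test")
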
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index eddb32b4d3..484875f989 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -15,9 +15,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process @@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return dict(txn) - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): + async def get_max_depth_of(self, event_ids: List[str]) -> int: """Returns the max depth of a set of event IDs Args: - event_ids (list[str]) - - Returns - Deferred[int] + event_ids: The event IDs to calculate the max depth of. """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return event_results - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.db_pool.runInteraction( + async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self.get_events_as_list(ids) + events = await self.get_events_as_list(ids) return events def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): @@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_results.reverse() return event_results - @defer.inlineCallbacks - def get_successor_events(self, event_ids): + async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]: """Fetch all events that have the given events as a prev event Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] + event_ids: The events to use as the previous events. 
""" - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): + async def _background_delete_non_state_event_auth(self, progress, batch_size): def delete_event_auth(txn): target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") @@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore): return min_stream_id >= target_min_stream_id - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_AUTH_STATE_ONLY ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index f618629e09..402ae25571 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,8 @@ import logging import re -from typing import Optional +from typing import Dict, List, Optional -from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes @@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 @@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_id", ) - @defer.inlineCallbacks - def is_trial_user(self, user_id): + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config Args: - user_id (str) - - Returns: - Deferred[bool] + user_id: The user to check for trial status. """ - info = yield self.get_user_by_id(user_id) + info = await self.get_user_by_id(user_id) if not info: return False @@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore): "get_user_by_access_token", self._query_for_auth, token ) - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): + @cached() + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: """Get the expiration timestamp for the account bearing a given user ID. Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). + None, if the account has no expiration timestamp, otherwise int + representation of the timestamp (as a number of milliseconds since epoch). 
""" - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", allow_none=True, desc="get_expiration_ts_for_user", ) - return res - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): + async def set_account_validity_for_user( + self, + user_id: str, + expiration_ts: int, + email_sent: bool, + renewal_token: Optional[str] = None, + ) -> None: """Updates the account validity properties of the given account, with the given values. Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds + user_id: ID of the account to update properties for. + expiration_ts: New expiration date, as a timestamp in milliseconds since epoch. - email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity + email_sent: True means a renewal email has been sent for this account + and there's no need to send another one for the current validity period. - renewal_token (str): Renewal token the user can use to extend the validity + renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. """ @@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): + async def set_renewal_token_for_user( + self, user_id: str, renewal_token: str + ) -> None: """Defines a renewal token for a given user. Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the + user_id: ID of the user to set the renewal token for. + renewal_token: Random unique string that will be used to renew the user's account. Raises: StoreError: The provided token is already set for another user. """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, desc="set_renewal_token_for_user", ) - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): + async def get_user_from_renewal_token(self, renewal_token: str) -> str: """Get a user ID from a renewal token. Args: - renewal_token (str): The renewal token to perform the lookup with. + renewal_token: The renewal token to perform the lookup with. Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. + The ID of the user to which the token belongs. """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", desc="get_user_from_renewal_token", ) - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. Args: - user_id (str): The user ID to lookup a token for. + user_id: The user ID to lookup a token for. 
Returns: - defer.Deferred[str]: The renewal token associated with this user ID. + The renewal token associated with this user ID. """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", desc="get_renewal_token_for_user", ) - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): + async def get_users_expiring_soon(self) -> List[Dict[str, int]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + A list of dictionaries mapping user ID to expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, values) return self.db_pool.cursor_to_dict(txn) - res = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), self.config.account_validity.renew_at, ) - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): + async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: """Sets or unsets the flag that indicates whether a renewal email has been sent to the user (and the user hasn't renewed their account yet). Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent + user_id: ID of the user to set/unset the flag for. + email_sent: Flag which indicates whether a renewal email has been sent to this user. """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, desc="set_renewal_mail_status", ) - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): + async def delete_account_validity_for_user(self, user_id: str) -> None: """Deletes the entry for the given user in the account validity table, removing their expiration date and renewal token. Args: - user_id (str): ID of the user to remove from the account validity table. + user_id: ID of the user to remove from the account validity table. """ - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", ) - async def is_server_admin(self, user): + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test + user: user ID of the user to test - Returns (bool): + Returns: true iff the user is a server admin, false otherwise. """ res = await self.db_pool.simple_select_one_onecol( @@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore): return None - @cachedInlineCallbacks() - def is_real_user(self, user_id): + @cached() + async def is_real_user(self, user_id: str) -> bool: """Determines if the user is a real user, ie does not have a 'user_type'. 
Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user 'user_type' is null or empty string + True if user 'user_type' is null or empty string """ - res = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "is_real_user", self.is_real_user_txn, user_id ) - return res @cached() - def is_support_user(self, user_id): + async def is_support_user(self, user_id: str) -> bool: """Determines if the user is of type UserTypes.SUPPORT Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT + True if user is of type UserTypes.SUPPORT """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) @@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_external_id", ) - @defer.inlineCallbacks - def count_all_users(self): + async def count_all_users(self): """Counts all users registered on the homeserver.""" def _count_users(txn): @@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore): return rows[0]["users"] return 0 - ret = yield self.db_pool.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) def count_daily_user_type(self): """ @@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore): "count_daily_user_type", _count_daily_user_type ) - @defer.inlineCallbacks - def count_nonbridged_users(self): + async def count_nonbridged_users(self): def _count_users(txn): txn.execute( """ @@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) - @defer.inlineCallbacks - def count_real_users(self): + async def count_real_users(self): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn): @@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore): return rows[0]["users"] return 0 - ret = yield self.db_pool.runInteraction("count_real_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_real_users", _count_users) async def generate_user_id(self) -> str: """Generate a suitable localpart for a guest user @@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore): return ret["user_id"] return None - @defer.inlineCallbacks - def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.db_pool.simple_upsert( + async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - @defer.inlineCallbacks - def user_get_threepids(self, user_id): - ret = yield self.db_pool.simple_select_list( + async def user_get_threepids(self, user_id): + return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], "user_get_threepids", ) - return ret def user_delete_threepid(self, user_id, medium, address): return self.db_pool.simple_delete( @@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_id_servers_user_bound", ) - @cachedInlineCallbacks() - def get_user_deactivated_status(self, 
user_id): + @cached() + async def get_user_deactivated_status(self, user_id: str) -> bool: """Retrieve the value for the `deactivated` property for the provided user. Args: - user_id (str): The ID of the user to retrieve the status for. + user_id: The ID of the user to retrieve the status for. Returns: - defer.Deferred(bool): The requested value. + True if the user was deactivated, false if the user is still active. """ - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - @defer.inlineCallbacks - def _background_update_set_deactivated_flag(self, progress, batch_size): + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. """ @@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): else: return False, len(rows) - end, nb_processed = yield self.db_pool.runInteraction( + end, nb_processed = await self.db_pool.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "users_set_deactivated_flag" ) return nb_processed - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): + async def _bg_user_threepids_grandfather(self, progress, batch_size): """We now track which identity servers a user binds their 3PID to, so we need to handle the case of existing bindings where we didn't track this. @@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self.db_pool.updates._end_background_update("user_threepids_grandfather") + await self.db_pool.updates._end_background_update("user_threepids_grandfather") return 1 @@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + async def add_access_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + valid_until_ms: Optional[int], + ) -> None: """Adds an access token for the given user. Args: - user_id (str): The user ID. - token (str): The new access token to add. - device_id (str): ID of the device to associate with the access - token - valid_until_ms (int|None): when the token is valid until. None for - no expiry. + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the access token + valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. 
""" next_id = self._access_tokens_id_gen.get_next() - yield self.db_pool.simple_insert( + await self.db_pool.simple_insert( "access_tokens", { "id": next_id, @@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str @@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return self.db_pool.runInteraction("delete_access_token", f) - @cachedInlineCallbacks() - def is_guest(self, user_id): - res = yield self.db_pool.simple_select_one_onecol( + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self.clock.time_msec(), ) - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: """Set the `deactivated` property for the provided user to the provided value. Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. """ - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + txn.call_after(self.is_guest.invalidate, (user_id,)) - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): + async def _set_expiration_date_when_missing(self): """ Retrieves the list of registered users that don't have an expiration date, and adds an expiration date for each of them. @@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, user["name"], use_delta=True ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 1e2d584098..139085b672 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine @@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["room_id"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): + async def _background_deduplicate_state(self, progress, batch_size): """This background update will slowly deduplicate state by reencoding them as deltas. 
""" @@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): return False, batch_size - finished, result = yield self.db_pool.runInteraction( + finished, result = await self.db_pool.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) if finished: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) return result * BATCH_SIZE_SCALE_FACTOR - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): + async def _background_index_state(self, progress, batch_size): def reindex_txn(conn): conn.rollback() if isinstance(self.database_engine, PostgresEngine): @@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.db_pool.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.STATE_GROUP_INDEX_UPDATE_NAME ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 6d45c4b233..e364b1bd62 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester +from tests.test_utils import make_awaitable from tests.unittest import override_config from .. 
import unittest @@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=defer.succeed(False)) + self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=defer.succeed(1)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=defer.succeed(2)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e793781a26..9870c74883 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -300,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), len(threepids)) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 71a40a0a49..840db66072 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -58,8 +58,10 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) result = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -74,11 +76,15 @@ class RegistrationStoreTestCase(unittest.TestCase): def test_user_delete_access_tokens(self): # add some tokens yield self.store.register_user(self.user_id, self.pwhash) - yield 
self.store.add_access_token_to_user( - self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + ) ) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) # now delete some -- cgit 1.5.1 From 04faa0bfa960d9f0dc60e9cf4ec270221249b7ca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Aug 2020 17:21:20 -0400 Subject: Convert tags and metrics databases to async/await (#8062) --- changelog.d/8062.misc | 1 + synapse/storage/databases/main/metrics.py | 20 ++-- synapse/storage/databases/main/tags.py | 103 +++++++++++---------- .../test_resource_limits_server_notices.py | 5 +- 4 files changed, 64 insertions(+), 65 deletions(-) create mode 100644 changelog.d/8062.misc (limited to 'tests') diff --git a/changelog.d/8062.misc b/changelog.d/8062.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8062.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index baa7a5092a..686052bd83 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -15,8 +15,6 @@ import typing from collections import Counter -from twisted.internet import defer - from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore @@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): res = await self.db_pool.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = Counter([x[0] for x in res]) - @defer.inlineCallbacks - def count_daily_messages(self): + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. @@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_messages", _count_messages) - return ret + return await self.db_pool.runInteraction("count_messages", _count_messages) - @defer.inlineCallbacks - def count_daily_sent_messages(self): + async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then thats your own fault. 
@@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_sent_messages", _count_messages ) - return ret - @defer.inlineCallbacks - def count_daily_active_rooms(self): + async def count_daily_active_rooms(self): def _count(txn): sql = """ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events @@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count) - return ret + return await self.db_pool.runInteraction("count_daily_active_rooms", _count) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index eedd2d96c3..e4e0a0c433 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,14 +15,13 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import Dict, List, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -30,30 +29,26 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - def get_tags_for_user(self, user_id): + async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for a user. Args: - user_id(str): The user to get the tags for. + user_id: The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. + A mapping from room_id strings to dicts mapping from tag strings to + tag content. """ - deferred = self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) - return tags_by_room - - return deferred + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = db_to_json(row["content"]) + return tags_by_room async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore): return results, upto_token, limited - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): + async def get_updated_tags( + self, user_id: str, stream_id: int + ) -> Dict[str, List[str]]: """Get all the tags for the rooms where the tags have changed since the given version Args: user_id(str): The user to get the tags for. stream_id(int): The earliest update to get for the user. + Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. + A mapping from room_id strings to lists of tag strings for all the + rooms that changed since the stream_id token. 
""" def get_updated_tags_txn(txn): @@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.db_pool.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_updated_tags", get_updated_tags_txn ) results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) + tags_by_room = await self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room.get(room_id, {}) return results - def get_tags_for_room(self, user_id, room_id): + async def get_tags_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the tags for the given room + Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for + user_id: The user to get tags for + room_id: The room to get tags for + Returns: - A deferred list of string tags. + A mapping of tags to tag content. """ - return self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) + return {row["tag"]: db_to_json(row["content"]) for row in rows} class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: """Add a tag to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the tag has been added. + The next account data ID. """ content_json = json.dumps(content) @@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) + await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: """Remove a tag from a room for a user. + Returns: - A deferred that completes once the tag has been removed + The next account data ID. 
""" def remove_tag_txn(txn, next_id): @@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) + await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_revision_txn(self, txn, user_id, room_id, next_id): + def _update_revision_txn( + self, txn, user_id: str, room_id: str, next_id: int + ) -> None: """Update the latest revision of the tags for the given user and room. Args: txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. + user_id: The ID of the user. + room_id: The ID of the room. + next_id: The the revision to advance to. """ txn.call_after( diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3f88abe3d2..2858d13558 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) + self._rlsn._store.get_tags_for_room = Mock( + side_effect=lambda user_id, room_id: make_awaitable({}) + ) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): -- cgit 1.5.1 From a3a59bab7bb3b69dcfc5620e6f3ac51af3f0f965 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 09:28:48 -0400 Subject: Convert appservice, group server, profile and more databases to async (#8066) --- changelog.d/8066.misc | 1 + synapse/storage/databases/main/appservice.py | 34 ++++------ synapse/storage/databases/main/filtering.py | 8 +-- synapse/storage/databases/main/group_server.py | 86 ++++++++++++-------------- synapse/storage/databases/main/presence.py | 7 +-- synapse/storage/databases/main/profile.py | 21 +++---- synapse/storage/databases/main/relations.py | 19 +++--- synapse/storage/databases/main/transactions.py | 7 +-- tests/storage/test_appservice.py | 24 +++---- 9 files changed, 91 insertions(+), 116 deletions(-) create mode 100644 changelog.d/8066.misc (limited to 'tests') diff --git a/changelog.d/8066.misc b/changelog.d/8066.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8066.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 055a3962dc..5cf1a88399 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -18,8 +18,6 @@ import re from canonicaljson import json -from twisted.internet import defer - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): + async def get_appservices_by_state(self, state): """Get a list of application services based on their state. Args: state(ApplicationServiceState): The state to filter on. Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. + A list of ApplicationServices, which may be empty. """ - results = yield self.db_pool.simple_select_list( + results = await self.db_pool.simple_select_list( "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - @defer.inlineCallbacks - def get_appservice_state(self, service): + async def get_appservice_state(self, service): """Get the application service state. Args: service(ApplicationService): The service whose state to set. Returns: - A Deferred which resolves to ApplicationServiceState. + An ApplicationServiceState. """ - result = yield self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one( "application_services_state", {"as_id": service.id}, ["state"], @@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): + async def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. Args: service(ApplicationService): The app service to get the oldest txn. Returns: - A Deferred which resolves to an AppServiceTransaction or - None. + An AppServiceTransaction or None. 
""" def _get_oldest_unsent_txn(txn): @@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.db_pool.runInteraction( + entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = db_to_json(entry["event_ids"]) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice(self, current_id, limit): """Get all new evnets""" def get_new_events_for_appservice_txn(txn): @@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cae6bda80e..45a1760170 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): + @cached(num_args=2) + async def get_user_filter(self, user_localpart, filter_id): # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.db_pool.simple_select_one_onecol( + def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 75ea6d4b2f..380db3a3f3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple - -from twisted.internet import defer +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util import json_encoder # The category ID for the "default" category. 
We don't store as null in the @@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_categories(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db_pool.simple_select_one( + async def get_group_category(self, group_id, category_id): + category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore): return category - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_roles(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db_pool.simple_select_one( + async def get_group_role(self, group_id, role_id): + role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. 
""" - row = yield self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - def get_groups_changes_for_user(self, user_id, from_token, to_token): + async def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( user_id, from_token ) if not has_changed: - return defer.succeed([]) + return [] def _get_groups_changes_for_user_txn(txn): sql = """ @@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_publicity", ) - @defer.inlineCallbacks - def register_user_group_membership( + async def register_user_group_membership( self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + group_id: str, + user_id: str, + membership: str, + is_admin: bool = False, + content: JsonDict = {}, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, + is_publicised: bool = False, + ) -> int: """Registers that a local user is a member of a (local or remote) group. Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter + group_id: The group the member is being added to. + user_id: THe user ID to add to the group. + membership: The type of group membership. + is_admin: Whether the user should be added as a group admin. + content: Content of the membership, e.g. includes the inviter if the user has been invited. - local_attestation (dict): If remote group then store the fact that we + local_attestation: If remote group then store the fact that we have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote + remote_attestation: If remote group then store the remote attestation from the group, else None. + is_publicised: Whether this should be publicised. 
""" def _register_user_group_membership_txn(txn, next_id): @@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db_pool.runInteraction( + res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, ) return res - @defer.inlineCallbacks - def create_group( + async def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( table="groups", values={ "group_id": group_id, @@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db_pool.simple_update_one( + async def update_group_profile(self, group_id, profile): + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 99e66dc6e9..59ba12820a 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,6 @@ from typing import List, Tuple -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList @@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): + async def update_presence(self, presence_states): stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) with stream_ordering_manager as stream_orderings: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 4a4f2cb385..b8261357d4 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart): try: - profile = yield self.db_pool.simple_select_one( + profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore): desc="update_remote_profile_cache", ) - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id): """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + subscribed = await self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore): _get_remote_profile_cache_entries_that_expire_txn, ) - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b81f1449b7..a9ceffc20e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Optional import attr from synapse.api.constants import RelationTypes +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( @@ -25,7 +27,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): + @cached() + async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. Correctly handles checking whether edits were allowed to happen. Args: - event_id (str): The original event ID + event_id: The original event ID Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. + The most recent edit, if any. 
""" # We only allow edits for `m.room.message` events that have the same sender @@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.db_pool.runInteraction( + edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) if not edit_id: - return + return None - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event + return await self.get_event(edit_id, allow_none=True) def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): """Check if a user has already annotated an event with the same key diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 8804c0e4ac..52668dbdf9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -18,8 +18,6 @@ from collections import namedtuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool @@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore): desc="set_received_txn_response", ) - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): + async def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. Args: @@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1b516b7976..98b74890d5 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservice_state_none(self): service = Mock(id="999") - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks def test_get_appservice_state_up(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks @@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks def test_get_appservices_by_state_none(self): - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @@ -339,7 +339,7 @@ class 
ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks @@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) @@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) @@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) self.assertEquals( -- cgit 1.5.1 From d68e10f308f89810e8d9ff94219cc68ca83f636d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 09:29:06 -0400 Subject: Convert account data, device inbox, and censor events databases to async/await (#8063) --- changelog.d/8063.misc | 1 + synapse/storage/databases/main/account_data.py | 77 +++++++++++--------- synapse/storage/databases/main/censor_events.py | 11 ++- synapse/storage/databases/main/deviceinbox.py | 94 +++++++++++++------------ tests/handlers/test_typing.py | 3 +- 5 files changed, 99 insertions(+), 87 deletions(-) create mode 100644 changelog.d/8063.misc (limited to 'tests') diff --git a/changelog.d/8063.misc b/changelog.d/8063.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8063.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
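Alongside the plain inlineCallbacks removals, the account data changes below also retire the cachedInlineCallbacks decorator: @cached can wrap a native coroutine directly, so the cached method itself becomes async. A rough sketch of that migration is given here; ExampleWorkerStore, the example_settings table and the method name are made up for illustration, and the db_pool attribute is assumed to exist as on the real store classes.

    from synapse.util.caches.descriptors import cached

    class ExampleWorkerStore:
        # Was: @cachedInlineCallbacks(num_args=2) over a generator that
        # yielded Deferreds. The cache decorator now sits on an async def.
        @cached(num_args=2)
        async def get_example_setting(self, user_id: str, setting_id: str):
            return await self.db_pool.simple_select_one_onecol(
                table="example_settings",
                keyvalues={"user_id": user_id, "setting_id": setting_id},
                retcol="content",
                allow_none=True,
                desc="get_example_setting",
            )
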
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index cf039e7f7d..82aac2bbf3 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,15 +16,16 @@ import abc import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -97,13 +98,15 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): + @cached(num_args=2, max_entries=5000) + async def get_global_account_data_by_type_for_user( + self, data_type: str, user_id: str + ) -> Optional[JsonDict]: """ Returns: - Deferred: A dict + The account data. """ - result = yield self.db_pool.simple_select_one_onecol( + result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -280,9 +283,11 @@ class AccountDataWorkerStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( + @cached(num_args=2, cache_context=True, max_entries=5000) + async def is_ignored_by( + self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext + ) -> bool: + ignored_account_data = await self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, @@ -307,24 +312,27 @@ class AccountDataStore(AccountDataWorkerStore): super(AccountDataStore, self).__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream Returns: - A deferred int. + The maximum stream ID. """ return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. 
""" content_json = json_encoder.encode(content) @@ -332,7 +340,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -350,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore): # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -359,18 +367,20 @@ class AccountDataStore(AccountDataWorkerStore): (user_id, room_id, account_data_type), content ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ content_json = json_encoder.encode(content) @@ -378,7 +388,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -396,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore): # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -404,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore): (account_data_type, user_id) ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id): + def _update_max_stream_id(self, next_id: int): """Update the max stream_id Args: - next_id(int): The the revision to advance to. + next_id: The the revision to advance to. 
""" # Note: This is only here for backwards compat to allow admins to diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 1de8249563..f211ddbaf8 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -16,8 +16,6 @@ import logging from typing import TYPE_CHECKING -from twisted.internet import defer - from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore @@ -148,17 +146,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase updatevalues={"json": pruned_json}, ) - @defer.inlineCallbacks - def expire_event(self, event_id): + async def expire_event(self, event_id: str) -> None: """Retrieve and expire an event that has expired, and delete its associated expiry timestamp. If the event can't be retrieved, delete its associated timestamp so we don't try to expire it again in the future. Args: - event_id (str): The ID of the event to delete. + event_id: The ID of the event to delete. """ # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) + event = await self.get_event(event_id) def delete_expired_event_txn(txn): # Delete the expiry timestamp associated with this event from the database. @@ -193,7 +190,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase txn, "_get_event_cache", (event.event_id,) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_expired_event", delete_expired_event_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 76ec954f44..1f6e995c4f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,8 +16,6 @@ import logging from typing import List, Tuple -from twisted.internet import defer - from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool @@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): + async def get_new_messages_for_device( + self, + user_id: str, + device_id: str, + last_stream_id: int, + current_stream_id: int, + limit: int = 100, + ) -> Tuple[List[dict], int]: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device + user_id: The recipient user_id. + device_id: The recipient device_id. + last_stream_id: The last stream ID checked. + current_stream_id: The current position of the to device message stream. + limit: The maximum number of messages to retrieve. + Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. 
""" has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id ) if not has_changed: - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ( @@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + async def delete_messages_for_device( + self, user_id: str, device_id: str, up_to_stream_id: int + ) -> int: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. + user_id: The recipient user_id. + device_id: The recipient device_id. + up_to_stream_id: Where to delete messages up to. + Returns: - A deferred that resolves to the number of messages deleted. + The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting @@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.db_pool.runInteraction( + count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): return count @trace - def get_new_device_msgs_for_remote( + async def get_new_device_msgs_for_remote( self, destination, last_stream_id, current_stream_id, limit - ): + ) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. @@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) @@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. 
- return defer.succeed(([], last_stream_id)) + return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): @@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.db_pool.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): + async def add_messages_to_device_inbox( + self, + local_messages_by_user_then_device: dict, + remote_messages_by_destination: dict, + ) -> int: """Used to send messages from this server. Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): + local_messages_by_user_and_device: Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): + remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. + Returns: - A deferred stream_id that resolves when the messages have been - inserted. + The new stream_id. """ def add_messages_txn(txn, now_ms, stream_id): @@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) return self._device_inbox_id_gen.get_current_token() - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): + async def add_messages_from_remote_to_device_inbox( + self, origin: str, message_id: str, local_messages_by_user_then_device: dict + ) -> int: def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. 
This can happen if the origin doesn't receive our @@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index b7d0adb10e..64ddd8243d 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import register_federation_servlets @@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None -- cgit 1.5.1 From 5dd73d029eff32668b3ca69b7fb8529fc7c58745 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2020 15:05:50 +0100 Subject: Add type hints to handlers.message and events.builder (#8067) --- changelog.d/8067.misc | 1 + mypy.ini | 3 ++ synapse/events/builder.py | 58 ++++++++++++++++++++----------------- synapse/handlers/message.py | 22 ++++++++------ synapse/handlers/room_member.py | 12 +++++--- tests/rest/client/test_retention.py | 4 ++- tox.ini | 2 ++ 7 files changed, 61 insertions(+), 41 deletions(-) create mode 100644 changelog.d/8067.misc (limited to 'tests') diff --git a/changelog.d/8067.misc b/changelog.d/8067.misc new file mode 100644 index 0000000000..f4404b7506 --- /dev/null +++ b/changelog.d/8067.misc @@ -0,0 +1 @@ +Add type hints to `synapse.handlers.message` and `synapse.events.builder`. 
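The builder.py changes below add types to an existing attrs class by passing type= to attr.ib rather than converting it to auto_attribs, which keeps runtime behaviour the same while giving mypy something to check once the module is listed in tox.ini. A small sketch of the pattern on an invented class (ExampleEventTemplate is not part of these patches):

    from typing import Optional

    import attr

    @attr.s
    class ExampleEventTemplate:
        # attr.ib(type=...) records the attribute type without changing how
        # attrs generates __init__, so this is purely a static-typing aid.
        room_id = attr.ib(type=str)
        sender = attr.ib(type=str)
        content = attr.ib(default=attr.Factory(dict), type=dict)
        _redacts = attr.ib(default=None, type=Optional[str])
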
diff --git a/mypy.ini b/mypy.ini index a61009b197..c69cb5dc40 100644 --- a/mypy.ini +++ b/mypy.ini @@ -81,3 +81,6 @@ ignore_missing_imports = True [mypy-rust_python_jaeger_reporter.*] ignore_missing_imports = True + +[mypy-nacl.*] +ignore_missing_imports = True diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 4e179d49b3..9ed24380dd 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,6 +17,7 @@ from typing import Optional import attr from nacl.signing import SigningKey +from synapse.api.auth import Auth from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -27,6 +28,8 @@ from synapse.api.room_versions import ( ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -42,45 +45,46 @@ class EventBuilder(object): Attributes: room_version: Version of the target room - room_id (str) - type (str) - sender (str) - content (dict) - unsigned (dict) - internal_metadata (_EventInternalMetadata) - - _state (StateHandler) - _auth (synapse.api.Auth) - _store (DataStore) - _clock (Clock) - _hostname (str): The hostname of the server creating the event + room_id + type + sender + content + unsigned + internal_metadata + + _state + _auth + _store + _clock + _hostname: The hostname of the server creating the event _signing_key: The signing key to use to sign the event as the server """ - _state = attr.ib() - _auth = attr.ib() - _store = attr.ib() - _clock = attr.ib() - _hostname = attr.ib() - _signing_key = attr.ib() + _state = attr.ib(type=StateHandler) + _auth = attr.ib(type=Auth) + _store = attr.ib(type=DataStore) + _clock = attr.ib(type=Clock) + _hostname = attr.ib(type=str) + _signing_key = attr.ib(type=SigningKey) room_version = attr.ib(type=RoomVersion) - room_id = attr.ib() - type = attr.ib() - sender = attr.ib() + room_id = attr.ib(type=str) + type = attr.ib(type=str) + sender = attr.ib(type=str) - content = attr.ib(default=attr.Factory(dict)) - unsigned = attr.ib(default=attr.Factory(dict)) + content = attr.ib(default=attr.Factory(dict), type=JsonDict) + unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict) # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. - _state_key = attr.ib(default=None) - _redacts = attr.ib(default=None) - _origin_server_ts = attr.ib(default=None) + _state_key = attr.ib(default=None, type=Optional[str]) + _redacts = attr.ib(default=None, type=Optional[str]) + _origin_server_ts = attr.ib(default=None, type=Optional[int]) internal_metadata = attr.ib( - default=attr.Factory(lambda: _EventInternalMetadata({})) + default=attr.Factory(lambda: _EventInternalMetadata({})), + type=_EventInternalMetadata, ) @property diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8ddded8389..2643438e84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json, json @@ -93,11 +93,11 @@ class MessageHandler(object): async def get_room_data( self, - user_id: str = None, - room_id: str = None, - event_type: Optional[str] = None, - state_key: str = "", - is_guest: bool = False, + user_id: str, + room_id: str, + event_type: str, + state_key: str, + is_guest: bool, ) -> dict: """ Get data from a room. @@ -407,7 +407,7 @@ class EventCreationHandler(object): # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are @@ -707,7 +707,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester: Requester, - event_dict: EventBase, + event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: @@ -971,7 +971,7 @@ class EventCreationHandler(object): # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = set() + original_alt_aliases = [] # type: List[str] original_event_id = event.unsigned.get("replaces_state") if original_event_id: @@ -1019,6 +1019,10 @@ class EventCreationHandler(object): current_state_ids = await context.get_current_state_ids() + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + state_to_include_ids = [ e_id for k, e_id in current_state_ids.items() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8e409f24e8..31705cdbdb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -37,6 +37,10 @@ from synapse.util.distributor import user_joined_room, user_left_room from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -48,7 +52,7 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -207,7 +211,7 @@ class RoomMemberHandler(object): return duplicate.event_id, stream_id stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit + requester, event, context, extra_users=[target], ratelimit=ratelimit, ) prev_state_ids = await context.get_prev_state_ids() @@ -1000,7 +1004,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): check_complexity = self.hs.config.limit_remote_rooms.enabled if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: - check_complexity = not await self.hs.auth.is_server_admin(user) + check_complexity = not await self.auth.is_server_admin(user) if check_complexity: # Fetch the room complexity diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e54ffea150..0b191d13c6 100644 --- 
a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -144,7 +144,9 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( - message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + message_handler.get_room_data( + self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + ) ) # Send a first event to the room. This is the event we'll want to be purged at the diff --git a/tox.ini b/tox.ini index 45e129580f..e5413eb110 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ commands = mypy \ synapse/appservice \ synapse/config \ synapse/event_auth.py \ + synapse/events/builder.py \ synapse/events/spamcheck.py \ synapse/federation \ synapse/handlers/auth.py \ @@ -186,6 +187,7 @@ commands = mypy \ synapse/handlers/directory.py \ synapse/handlers/federation.py \ synapse/handlers/identity.py \ + synapse/handlers/message.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/room_member.py \ -- cgit 1.5.1 From 5ecc8b58255d7e33ad63a6c931efa6ed5e41ad01 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 10:51:42 -0400 Subject: Convert devices database to async/await. (#8069) --- changelog.d/8069.misc | 1 + synapse/storage/databases/main/devices.py | 333 ++++++++++++++++-------------- tests/handlers/test_typing.py | 2 +- tests/storage/test_devices.py | 44 ++-- tests/storage/test_end_to_end_keys.py | 16 +- 5 files changed, 220 insertions(+), 176 deletions(-) create mode 100644 changelog.d/8069.misc (limited to 'tests') diff --git a/changelog.d/8069.misc b/changelog.d/8069.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8069.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 7a5f0bab05..2b33060480 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -33,14 +31,9 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_encoder -from synapse.util.caches.descriptors import ( - Cache, - cached, - cachedInlineCallbacks, - cachedList, -) +from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): + def get_device(self, user_id: str, device_id: str): """Retrieve a device. Only returns devices that are not marked as hidden. 
Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve Returns: defer.Deferred for a dict containing the device information Raises: @@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_device", ) - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. Args: - user_id (str): + user_id: Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. + A mapping from device_id to a dict containing "device_id", "user_id" + and "display_name" for each device. """ - devices = yield self.db_pool.simple_select_list( + devices = await self.db_pool.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} @trace - @defer.inlineCallbacks - def get_device_updates_by_remote(self, destination, from_stream_id, limit): + async def get_device_updates_by_remote( + self, destination: str, from_stream_id: int, limit: int + ) -> Tuple[int, List[Tuple[str, dict]]]: """Get a stream of device updates to send to the given remote server. Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - limit (int): Maximum number of device updates to return + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + limit: Maximum number of device updates to return + Returns: - Deferred[tuple[int, list[tuple[string,dict]]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates, where each update is a pair of EDU - type and EDU contents + A mapping from the current stream id (ie, the stream id of the last + update included in the response), and the list of updates, where + each update is a pair of EDU type and EDU contents. 
""" now_stream_id = self._device_list_id_gen.get_current_token() @@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - updates = yield self.db_pool.runInteraction( + updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield defer.ensureDeferred( - self.get_e2e_cross_signing_key(user, "master") - ) + cross_signing_key = await self.get_e2e_cross_signing_key(user, "master") if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield defer.ensureDeferred( - self.get_e2e_cross_signing_key(user, "self_signing") + cross_signing_key = await self.get_e2e_cross_signing_key( + user, "self_signing" ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( @@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - results = yield self._get_device_update_edus_by_remote( + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, results def _get_device_updates_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit + self, + txn: LoggingTransaction, + destination: str, + from_stream_id: int, + now_stream_id: int, + limit: int, ): """Return device update information for a given remote destination Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return + txn: The transaction to execute + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + now_stream_id: The maximum stream_id to filter updates by, inclusive + limit: Maximum number of device updates to return Returns: List: List of device updates @@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + async def _get_device_update_edus_by_remote( + self, + destination: str, + from_stream_id: int, + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]], + ) -> List[Tuple[str, dict]]: """Returns a list of device update EDUs as well as E2EE keys Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: - List[Dict]: List of objects representing an device update EDU - + List of objects representing an device update EDU """ devices = ( - yield self.db_pool.runInteraction( + await 
self.db_pool.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore): for user_id, user_devices in devices.items(): # The prev_id for the first row is always the last row before # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( + prev_id = await self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) @@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore): return results def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id + self, destination: str, user_id: str, from_stream_id: int ): def f(txn): prev_sent_id_sql = """ @@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore): return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) - def mark_as_sent_devices_by_remote(self, destination, stream_id): + def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): """Mark that updates have successfully been sent to the destination. """ return self.db_pool.runInteraction( @@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore): stream_id, ) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + def _mark_as_sent_devices_by_remote_txn( + self, txn: LoggingTransaction, destination: str, stream_id: int + ) -> None: # We update the device_lists_outbound_last_success with the successfully # poked users. sql = """ @@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) - @defer.inlineCallbacks - def add_user_signature_change_to_streams(self, from_user_id, user_ids): + async def add_user_signature_change_to_streams( + self, from_user_id: str, user_ids: List[str] + ) -> int: """Persist that a user has made new signatures Args: - from_user_id (str): the user who made the signatures - user_ids (list[str]): the users who were signed + from_user_id: the user who made the signatures + user_ids: the users who were signed + + Returns: + THe new stream ID. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore): ) return stream_id - def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + def _add_user_signature_change_txn( + self, + txn: LoggingTransaction, + from_user_id: str, + user_ids: List[str], + stream_id: int, + ) -> None: txn.call_after( self._user_signature_stream_cache.entity_has_changed, from_user_id, @@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore): }, ) - def get_device_stream_token(self): + def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): + async def get_user_devices_from_cache( + self, query_list: List[Tuple[str, str]] + ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list(list): List of (user_id, device_ids), if device_ids is + query_list: List of (user_id, device_ids), if device_ids is falsey then return all device ids for that user. 
Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info + A tuple of (user_ids_not_in_cache, results_map), where + user_ids_not_in_cache is a set of user_ids and results_map is a + mapping of user_id -> device_id -> device_info. """ user_ids = {user_id for user_id, _ in query_list} - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. - users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( user_ids ) user_ids_in_cache = { @@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore): continue if device_id: - device = yield self._get_cached_user_device(user_id, device_id) + device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = yield self.get_cached_devices_for_user(user_id) + results[user_id] = await self.get_cached_devices_for_user(user_id) set_tag("in_cache", results) set_tag("not_in_cache", user_ids_not_in_cache) return user_ids_not_in_cache, results - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self.db_pool.simple_select_one_onecol( + @cached(num_args=2, tree=True) + async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore): ) return db_to_json(content) - @cachedInlineCallbacks() - def get_cached_devices_for_user(self, user_id): - devices = yield self.db_pool.simple_select_list( + @cached() + async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id): + def get_devices_with_keys_by_user(self, user_id: str): """Get all devices (with any device keys) for a user Returns: - (stream_id, devices) + Deferred which resolves to (stream_id, devices) """ return self.db_pool.runInteraction( "get_devices_with_keys_by_user", @@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore): user_id, ) - def _get_devices_with_keys_by_user_txn(self, txn, user_id): + def _get_devices_with_keys_by_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: now_stream_id = self._device_list_id_gen.get_current_token() devices = self._get_e2e_device_keys_txn( @@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, [] - def get_users_whose_devices_changed(self, from_key, user_ids): + async def get_users_whose_devices_changed( + self, from_key: str, user_ids: Iterable[str] + ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. 
Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) + from_key: The device lists stream token + user_ids: The user IDs to query for devices. Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` + The set of user_ids whose devices have changed since `from_key` """ from_key = int(from_key) @@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore): ) if not to_check: - return defer.succeed(set()) + return set() def _get_users_whose_devices_changed_txn(txn): changes = set() @@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - @defer.inlineCallbacks - def get_users_whose_signatures_changed(self, user_id, from_key): + async def get_users_whose_signatures_changed( + self, user_id: str, from_key: str + ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. Args: - user_id (str): the user who made the signatures - from_key (str): The device lists stream token + user_id: the user who made the signatures + from_key: The device lists stream token + + Returns: + A set of user IDs with updated signatures. """ from_key = int(from_key) if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): @@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} @@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): + def get_device_list_last_stream_id_for_remote(self, user_id: str): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ @@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore): list_name="user_ids", inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): + def get_device_list_last_stream_id_for_remotes(self, user_ids: str): rows = yield self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore): return results - @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync( + async def get_user_ids_requiring_device_list_resync( self, user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we @@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore): The IDs of users whose device lists need resync. 
""" if user_ids: - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", column="user_id", iterable=user_ids, @@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_user_ids_requiring_device_list_resync_with_iterable", ) else: - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, retcols=("user_id",), @@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id): + def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): """Mark that we no longer track device lists for remote user. """ @@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "drop_device_lists_outbound_last_success_non_unique_idx", ) - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.db_pool.runWithConnection(f) - yield self.db_pool.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES ) return 1 @@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): + async def store_device( + self, user_id: str, device_id: str, initial_device_display_name: str + ) -> bool: """Ensure the given device is known; add it to the store if not Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. + user_id: id of user associated with the device + device_id: id of device + initial_device_display_name: initial displayname of the device. + Ignored if device exists. + Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. + Whether the device was inserted or an existing device existed with that ID. 
+ Raises: StoreError: if the device is already in use """ @@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.db_pool.simple_insert( + inserted = await self.db_pool.simple_insert( "devices", values={ "user_id": user_id, @@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.db_pool.simple_select_one_onecol( + hidden = await self.db_pool.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """Delete a device. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the device + device_id: The ID of the device to delete """ - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.device_id_exists_cache.invalidate((user_id, device_id)) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the devices + device_ids: The IDs of the devices to delete """ - yield self.db_pool.simple_delete_many( + await self.db_pool.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - def update_device(self, user_id, device_id, new_display_name=None): + async def update_device( + self, user_id: str, device_id: str, new_display_name: Optional[str] = None + ) -> None: """Update a device. Only updates the device if it is not marked as hidden. 
Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged + user_id: The ID of the user which owns the device + device_id: The ID of the device to update + new_display_name: new displayname for device; None to leave unchanged Raises: StoreError: if the device is not found - Returns: - defer.Deferred """ updates = {} if new_display_name is not None: updates["display_name"] = new_display_name if not updates: - return defer.succeed(None) - return self.db_pool.simple_update_one( + return None + await self.db_pool.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id + self, user_id: str, device_id: str, content: JsonDict, stream_id: int ): """Updates a single device in the cache of a remote user's devicelist. @@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device list. Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list + user_id: User to update device list for + device_id: ID of decivice being updated + content: new data on this device + stream_id: the version of the device list Returns: Deferred[None] @@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + content: JsonDict, + stream_id: int, + ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( txn, @@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache(self, user_id, devices, stream_id): + def update_remote_device_list_cache( + self, user_id: str, devices: List[dict], stream_id: int + ): """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's device list. 
Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list + user_id: User to update device list for + devices: list of device objects supplied over federation + stream_id: the version of the device list Returns: Deferred[None] @@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id, ) - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): + def _update_remote_device_list_cache_txn( + self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int + ): self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) @@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, ) - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): + async def add_device_change_to_streams( + self, user_id: str, device_ids: Collection[str], hosts: List[str] + ): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ @@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, user_id, @@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_device_outbound_poke_to_stream", self._add_device_outbound_poke_to_stream_txn, user_id, @@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _add_device_outbound_poke_to_stream_txn( - self, txn, user_id, device_ids, hosts, stream_ids, context, + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + hosts: List[str], + stream_ids: List[str], + context: Dict[str, str], ): for host in hosts: txn.call_after( @@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): + def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. 
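The hunks above apply one mechanical pattern across devices.py: an @defer.inlineCallbacks generator that yields Deferreds becomes a native coroutine that awaits them, and @cachedInlineCallbacks becomes @cached on an async def. A minimal, self-contained sketch of the first half of that pattern, using plain Twisted rather than Synapse code (the function names and the fake query result are invented for illustration):

    from twisted.internet import defer


    @defer.inlineCallbacks
    def get_device_old(user_id, device_id):
        # Old style: a generator decorated with @defer.inlineCallbacks that
        # yields Deferreds; defer.succeed stands in for a database lookup.
        device = yield defer.succeed({"user_id": user_id, "device_id": device_id})
        return device


    async def get_device_new(user_id, device_id):
        # New style: a native coroutine; Deferreds can be awaited directly.
        device = await defer.succeed({"user_id": user_id, "device_id": device_id})
        return device


    if __name__ == "__main__":
        # An inlineCallbacks function already returns a Deferred; a coroutine
        # must be wrapped with ensureDeferred to get one.
        get_device_old("@user:example.com", "ABCDEFGH").addCallback(print)
        defer.ensureDeferred(
            get_device_new("@user:example.com", "ABCDEFGH")
        ).addCallback(print)
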
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64ddd8243d..64afd581bc 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( (0, []) ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index c2539b353a..87ed8f8cd1 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_store_new_device(self): - yield self.store.store_device("user_id", "device_id", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name") + ) res = yield self.store.get_device("user_id", "device_id") self.assertDictContainsSubset( @@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self.store.store_device("user_id", "device1", "display_name 1") - yield self.store.store_device("user_id", "device2", "display_name 2") - yield self.store.store_device("user_id2", "device3", "display_name 3") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) - res = yield self.store.get_devices_by_user("user_id") + res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"] + yield defer.ensureDeferred( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "somehost", -1, limit=100 + now_stream_id, device_updates = yield defer.ensureDeferred( + self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) # Check original device_ids are contained within these updates @@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_device(self): - yield self.store.store_device("user_id", "device_id", "display_name 1") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name 1") + ) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield self.store.update_device("user_id", "device_id") + yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do the update - yield self.store.update_device( - "user_id", "device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "device_id", new_display_name="display_name 2" + ) ) # 
check it worked @@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_unknown_device(self): with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ) ) self.assertEqual(404, cm.exception.code) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 9f8d30373b..d57cdffd8b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) yield self.store.set_e2e_device_keys("user", "device", now, json) @@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) changed = yield self.store.set_e2e_device_keys("user", "device", now, json) self.assertTrue(changed) @@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): json = {"key": "value"} yield self.store.set_e2e_device_keys("user", "device", now, json) - yield self.store.store_device("user", "device", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user", "device", "display_name") + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user", "device"),)) @@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): def test_multiple_devices(self): now = 1470174257070 - yield self.store.store_device("user1", "device1", None) - yield self.store.store_device("user1", "device2", None) - yield self.store.store_device("user2", "device1", None) - yield self.store.store_device("user2", "device2", None) + yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) -- cgit 1.5.1 From fbe930dad28c81a5e563ddc8683f65b8279aad52 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 12:14:34 -0400 Subject: Convert the roommember database to async/await. (#8070) --- changelog.d/8070.misc | 1 + synapse/storage/_base.py | 1 - synapse/storage/databases/main/push_rule.py | 75 -------- synapse/storage/databases/main/roommember.py | 263 ++++++++++----------------- tests/test_federation.py | 18 +- 5 files changed, 116 insertions(+), 242 deletions(-) create mode 100644 changelog.d/8070.misc (limited to 'tests') diff --git a/changelog.d/8070.misc b/changelog.d/8070.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8070.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
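The device and end-to-end key tests above stay on @defer.inlineCallbacks, so every call into a now-async store method is wrapped in defer.ensureDeferred before being yielded. A trial-style sketch of that calling convention, with a stand-in store rather than Synapse's real classes (FakeDeviceStore is invented here; only the shape of the call matters):

    from twisted.internet import defer
    from twisted.trial import unittest


    class FakeDeviceStore:
        # Stand-in for the converted storage class; just enough to be called.
        async def store_device(self, user_id, device_id, display_name):
            return True


    class EnsureDeferredSketchTestCase(unittest.TestCase):
        @defer.inlineCallbacks
        def test_store_new_device(self):
            store = FakeDeviceStore()
            # Yielding the bare coroutine would hand the coroutine object back
            # unawaited; ensureDeferred turns it into a Deferred that
            # inlineCallbacks knows how to wait on.
            inserted = yield defer.ensureDeferred(
                store.store_device("@user:test", "DEVICE1", "display name")
            )
            self.assertTrue(inserted)
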
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ca800df831..6814bf5fcf 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta): """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) - self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 19a0211a03..6562db5c2b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -256,81 +256,6 @@ class PushRulesWorkerStore( ): yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. 
- local_users_in_room = { - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - } - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 7c5be251bd..b2fcfc9bfe 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Iterable, List, Set +from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import ( @@ -40,9 +42,12 @@ from synapse.storage.roommember import ( from synapse.types import Collection, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.state import _StateCacheEntry + logger = logging.getLogger(__name__) @@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): + def get_users_in_room(self, room_id: str): return self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id): + def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id): + def get_room_summary(self, room_id: str): """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: - room_id (str): The room ID to query + room_id: The room ID to query Returns: Deferred[dict[str, MemberSummary]: dict of membership states, pointing to a MemberSummary named tuple. 
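The membership helpers being converted here read from the current_state_events and room_memberships tables, and the per-membership counting helper removed in the hunk that follows is the most compact illustration of the query shape. Here is a toy, runnable version of that same query against a deliberately stripped-down sqlite schema (only the columns the join touches, not Synapse's real tables):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE current_state_events (event_id TEXT, room_id TEXT, type TEXT);
        CREATE TABLE room_memberships (event_id TEXT, user_id TEXT, membership TEXT);
        """
    )
    members = [
        ("$1", "@alice:test", "join"),
        ("$2", "@bob:test", "join"),
        ("$3", "@carol:test", "leave"),
    ]
    for event_id, user_id, membership in members:
        conn.execute(
            "INSERT INTO current_state_events VALUES (?, ?, ?)",
            (event_id, "!room:test", "m.room.member"),
        )
        conn.execute(
            "INSERT INTO room_memberships VALUES (?, ?, ?)",
            (event_id, user_id, membership),
        )

    # The counting query being dropped below, unchanged apart from formatting.
    sql = """
        SELECT m.membership, count(*) FROM room_memberships as m
        INNER JOIN current_state_events as c USING(event_id)
        WHERE c.type = 'm.room.member' AND c.room_id = ?
        GROUP BY m.membership
    """
    print(dict(conn.execute(sql, ("!room:test",))))  # e.g. {'join': 2, 'leave': 1}
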
@@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} - @cached() - def get_invited_rooms_for_local_user(self, user_id): - """ Get all the rooms the *local* user is invited to + def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + """Get all the rooms the *local* user is invited to. Args: - user_id (str): The user ID. + user_id: The user ID. + Returns: - A deferred list of RoomsForUser. + A awaitable list of RoomsForUser. """ return self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) - @defer.inlineCallbacks - def get_invite_for_local_user_in_room(self, user_id, room_id): - """Gets the invite for the given *local* user and room + async def get_invite_for_local_user_in_room( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUser]: + """Gets the invite for the given *local* user and room. Args: - user_id (str) - room_id (str) + user_id: The user ID to find the invite of. + room_id: The room to user was invited to. Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. + Either a RoomsForUser or None if no invite was found. """ - invites = yield self.get_invited_rooms_for_local_user(user_id) + invites = await self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None - @defer.inlineCallbacks - def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this *local* user where the membership for this user + async def get_rooms_for_local_user_where_membership_is( + self, user_id: str, membership_list: List[str] + ) -> Optional[List[RoomsForUser]]: + """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. + user_id: The user ID. + membership_list: A list of synapse.api.constants.Membership + values which the user must be in. Returns: - Deferred[list[RoomsForUser]] + The RoomsForUser that the user matches the membership types. 
""" if not membership_list: - return defer.succeed(None) + return None - rooms = yield self.db_pool.runInteraction( + rooms = await self.db_pool.runInteraction( "get_rooms_for_local_user_where_membership_is", self._get_rooms_for_local_user_where_membership_is_txn, user_id, @@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): + self, txn, user_id: str, membership_list: List[str] + ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): raise Exception( @@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): + def get_rooms_for_user_with_stream_ordering(self, user_id: str): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. Args: - user_id (str) + user_id Returns: Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns @@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): + async def get_rooms_for_user(self, user_id: str, on_invalidate=None): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. 
""" - rooms = yield self.get_rooms_for_user_with_stream_ordering( + rooms = await self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate ) return frozenset(r.room_id for r in rooms) - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): + @cached(max_entries=500000, cache_context=True, iterable=True) + async def get_users_who_share_room_with_user( + self, user_id: str, cache_context: _CacheContext + ) -> Set[str]: """Returns the set of users who share a room with `user_id` """ - room_ids = yield self.get_rooms_for_user( + room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate ) user_who_share_room = set() for room_id in room_ids: - user_ids = yield self.get_users_in_room( + user_ids = await self.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) user_who_share_room.update(user_ids) return user_who_share_room - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): + async def get_joined_users_from_context( + self, event: EventBase, context: EventContext + ): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._get_joined_users_from_context( + current_state_ids = await context.get_current_state_ids() + return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) - return result - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + return await self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry ) - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( + @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) + async def _get_joined_users_from_context( self, room_id, state_group, @@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. 
assert state_group is not None users_in_room = {} @@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) users_in_room.update((row for row in event_to_memberships.values() if row)) @@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): list_name="event_ids", inlineCallbacks=True, ) - def _get_joined_profiles_from_event_ids(self, event_ids): + def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. Args: - event_ids (Iterable[str]): The member event IDs to lookup + event_ids: The member event IDs to lookup Returns: Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID @@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): for row in rows } - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): + @cached(max_entries=10000) + async def is_host_joined(self, room_id: str, host: str) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") @@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "is_host_joined", None, sql, room_id, like_clause ) @@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. 
- like_clause = "%:" + host - - rows = yield self.db_pool.execute( - "was_host_joined", None, sql, room_id, like_clause - ) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): + async def get_joined_hosts(self, room_id: str, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) + return await self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry ) - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + @cached(num_args=2, max_entries=10000, iterable=True) + async def _get_joined_hosts( + self, room_id, state_group, current_state_ids, state_entry + ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = yield self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts + cache = await self._get_joined_hosts_cache(room_id) + return await cache.get_destinations(state_entry) @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): + def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache(self, room_id) - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): + @cached(num_args=2) + async def did_forget(self, user_id: str, room_id: str) -> bool: """Returns whether user_id has elected to discard history for room_id. Returns False if they have since re-joined.""" @@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = txn.fetchall() return rows[0][0] - count = yield self.db_pool.runInteraction("did_forget_membership", f) + count = await self.db_pool.runInteraction("did_forget_membership", f) return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id): + def get_forgotten_rooms_for_user(self, user_id: str): """Gets all rooms the user has forgotten. Args: - user_id (str) + user_id Returns: Deferred[set[str]] @@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. Args: - user_id (str) + user_id: The user ID to get the rooms of. Returns: - Deferred[set[str]]: Set of room IDs. + Set of room IDs. 
""" - room_ids = yield self.db_pool.simple_select_onecol( + room_ids = await self.db_pool.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): where_clause="forgotten = 1", ) - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( "target_min_stream_id_inclusive", self._min_stream_order_on_start ) @@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): return len(rows) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( _MEMBERSHIP_PROFILE_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership(self, progress, batch_size): """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. @@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.db_pool.runInteraction( + row_count, finished = await self.db_pool.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, ) if finished: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME ) @@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id, room_id): + def forget(self, user_id: str, room_id: str): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object): self._len = 0 - @defer.inlineCallbacks - def get_destinations(self, state_entry): + async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: """Get set of destinations for a state entry Args: - state_entry(synapse.state._StateCacheEntry) + state_entry + + Returns: + The destinations as a set. 
""" if state_entry.state_group == self.state_group: return frozenset(self.hosts_to_joined_users) - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_entry.state_group == self.state_group: pass elif state_entry.prev_group == self.state_group: @@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object): user_id = state_key known_joins = self.hosts_to_joined_users.setdefault(host, set()) - event = yield self.store.get_event(event_id) + event = await self.store.get_event(event_id) if event.membership == Membership.JOIN: known_joins.add(user_id) else: @@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object): if not known_joins: self.hosts_to_joined_users.pop(host, None) else: - joined_users = yield self.store.get_joined_users_from_state( + joined_users = await self.store.get_joined_users_from_state( self.room_id, state_entry ) diff --git a/tests/test_federation.py b/tests/test_federation.py index c2f12c2741..f2fa42bfb9 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed @@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver +from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) + store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. -- cgit 1.5.1 From dd8f28bd3fedb74080916cf0d03e6957b2978651 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Aug 2020 07:11:39 -0400 Subject: Fix unawaited coroutine error in tests. (#8072) --- changelog.d/8072.misc | 1 + tests/federation/test_complexity.py | 30 ++++++++++++++++++++---------- 2 files changed, 21 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8072.misc (limited to 'tests') diff --git a/changelog.d/8072.misc b/changelog.d/8072.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8072.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. 
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index b8ca118716..9bd515080c 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,9 +79,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -110,9 +112,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -148,9 +152,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable(None) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -204,9 +210,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -234,9 +242,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( -- cgit 1.5.1 From ac77cdb64e50c9fdfc00cccbc7b96f42057aa741 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 12:37:59 -0400 Subject: Add a shadow-banned flag to users. 
(#8092) --- changelog.d/8092.feature | 1 + synapse/api/auth.py | 12 ++++++++++- synapse/handlers/register.py | 8 +++++++ synapse/replication/http/register.py | 4 ++++ synapse/storage/databases/main/registration.py | 9 +++++++- .../main/schema/delta/58/09shadow_ban.sql | 18 ++++++++++++++++ synapse/types.py | 25 +++++++++++++++++++--- tests/storage/test_cleanup_extrems.py | 4 ++-- tests/storage/test_event_metrics.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_federation.py | 2 +- tests/unittest.py | 8 +++++-- 12 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8092.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql (limited to 'tests') diff --git a/changelog.d/8092.feature b/changelog.d/8092.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8092.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d8190f92ab..7aab764360 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -213,6 +213,7 @@ class Auth(object): user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] + shadow_banned = user_info["shadow_banned"] # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: @@ -252,7 +253,12 @@ class Auth(object): opentracing.set_tag("device_id", device_id) return synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service + user, + token_id, + is_guest, + shadow_banned, + device_id, + app_service=app_service, ) except KeyError: raise MissingClientTokenError() @@ -297,6 +303,7 @@ class Auth(object): dict that includes: `user` (UserID) `is_guest` (bool) + `shadow_banned` (bool) `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: @@ -356,6 +363,7 @@ class Auth(object): ret = { "user": user, "is_guest": True, + "shadow_banned": False, "token_id": None, # all guests get the same device id "device_id": GUEST_DEVICE_ID, @@ -365,6 +373,7 @@ class Auth(object): ret = { "user": user, "is_guest": False, + "shadow_banned": False, "token_id": None, "device_id": None, } @@ -488,6 +497,7 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "shadow_banned": ret.get("shadow_banned"), "device_id": ret.get("device_id"), "valid_until_ms": ret.get("valid_until_ms"), } diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c94209ab3d..999bc6efb5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, + shadow_banned=False, ): """Registers a new client on the server. @@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. + shadow_banned (bool): Shadow-ban the created user. 
Returns: str: user_id Raises: @@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) if self.hs.config.user_directory_search_all_users: @@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler): make_guest=make_guest, create_profile_with_displayname=default_display_name, address=address, + shadow_banned=shadow_banned, ) # Successfully registered @@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler): admin=False, user_type=None, address=None, + shadow_banned=False, ): """Register user in the datastore. @@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the registration. + shadow_banned (bool): Whether to shadow-ban the user Returns: Awaitable @@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) else: return self.store.register_user( @@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + shadow_banned=shadow_banned, ) async def register_device( diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index ce9420aa69..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7965a52e30..de50fa6e94 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -304,7 +304,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," + "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" @@ -952,6 +952,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname=None, admin=False, user_type=None, + shadow_banned=False, ): """Attempts to register an account. @@ -968,6 +969,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. 
+ shadow_banned (bool): Whether the user is shadow-banned, + i.e. they may be told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. @@ -986,6 +989,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ) def _register_user( @@ -999,6 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ): user_id_obj = UserID.from_string(user_id) @@ -1028,6 +1033,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) else: @@ -1042,6 +1048,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql new file mode 100644 index 0000000000..260b009b48 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A shadow-banned user may be told that their requests succeeded when they were +-- actually ignored. +ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN; diff --git a/synapse/types.py b/synapse/types.py index 9e580f4295..bc36cdde30 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,15 @@ JsonDict = Dict[str, Any] class Requester( namedtuple( - "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] + "Requester", + [ + "user", + "access_token_id", + "is_guest", + "shadow_banned", + "device_id", + "app_service", + ], ) ): """ @@ -62,6 +70,7 @@ class Requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request has been shadow-banned. 
device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user """ @@ -77,6 +86,7 @@ class Requester( "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, } @@ -101,13 +111,19 @@ class Requester( user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, ) def create_requester( - user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None + user_id, + access_token_id=None, + is_guest=False, + shadow_banned=False, + device_id=None, + app_service=None, ): """ Create a new ``Requester`` object @@ -117,6 +133,7 @@ def create_requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user @@ -125,7 +142,9 @@ def create_requester( """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id, app_service) + return Requester( + user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + ) def get_domain_from_id(string): diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 3fab5a5248..8e9a650f9f 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b85004e5..949846fe33 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 17c9da4838..d98fe8754d 100644 --- 
a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py index f2fa42bfb9..4a4548433f 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -42,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, None, None) + our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() room_deferred = ensureDeferred( room_creator.create_room( diff --git a/tests/unittest.py b/tests/unittest.py index d0bba3ddef..7b80999a74 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase): async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + UserID.from_string(self.helper.auth_user_id), + 1, + False, + False, + None, ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -540,7 +544,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) event, context = self.get_success( event_creator.create_event( -- cgit 1.5.1 From ad6190c9252aafd37cd8c229b70853bfc4ef0e64 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 07:24:46 -0400 Subject: Convert stream database to async/await. (#8074) --- changelog.d/8074.misc | 1 + synapse/api/filtering.py | 2 +- synapse/api/presence.py | 69 ++++ synapse/federation/send_queue.py | 2 +- synapse/federation/sender/__init__.py | 2 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/storage/databases/main/presence.py | 2 +- synapse/storage/databases/main/stream.py | 387 +++++++++++---------- synapse/storage/presence.py | 69 ---- tests/handlers/test_presence.py | 2 +- tests/storage/test_purge.py | 49 +-- 12 files changed, 293 insertions(+), 296 deletions(-) create mode 100644 changelog.d/8074.misc create mode 100644 synapse/api/presence.py delete mode 100644 synapse/storage/presence.py (limited to 'tests') diff --git a/changelog.d/8074.misc b/changelog.d/8074.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8074.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
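The stream-store hunks below repeat one mechanical conversion many times over: @defer.inlineCallbacks generators become async def coroutines, yield becomes await, and type information moves from docstrings into annotations. Here is a rough before/after sketch of that pattern, loosely modelled on the get_stream_token_for_event change further down; the get_stream_token name and the FakePool class are invented for the illustration and stand in for the real method and DatabasePool.

from twisted.internet import defer


class FakePool:
    # Stand-in for Synapse's DatabasePool, defined only for this sketch; like
    # the real simple_select_* helpers it returns something awaitable.
    def simple_select_one_onecol(self, table, keyvalues, retcol):
        return defer.succeed(42)


class BeforeStore:
    # Old style: an @defer.inlineCallbacks generator driven with `yield`,
    # returning a Deferred to its caller.
    db_pool = FakePool()

    @defer.inlineCallbacks
    def get_stream_token(self, event_id):
        row = yield self.db_pool.simple_select_one_onecol(
            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
        )
        return "s%d" % (row,)


class AfterStore:
    # New style: a native coroutine. Callers must now `await` it (or wrap it
    # with defer.ensureDeferred) instead of yielding the returned Deferred.
    db_pool = FakePool()

    async def get_stream_token(self, event_id: str) -> str:
        row = await self.db_pool.simple_select_one_onecol(
            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
        )
        return "s%d" % (row,)


BeforeStore().get_stream_token("$ev").addCallback(print)  # prints "s42"
defer.ensureDeferred(AfterStore().get_stream_token("$ev")).addCallback(print)  # prints "s42"

The design consequence visible throughout the diff is at the call sites: a native coroutine is not implicitly a Deferred, so callers either await it from another coroutine or wrap it with defer.ensureDeferred (as the tests earlier in this series already do) at the boundary with Twisted code.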
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 7393d6cb74..a8937d2595 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -23,7 +23,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError -from synapse.storage.presence import UserPresenceState +from synapse.api.presence import UserPresenceState from synapse.types import RoomID, UserID FILTER_SCHEMA = { diff --git a/synapse/api/presence.py b/synapse/api/presence.py new file mode 100644 index 0000000000..18a462f0ee --- /dev/null +++ b/synapse/api/presence.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple + +from synapse.api.constants import PresenceState + + +class UserPresenceState( + namedtuple( + "UserPresenceState", + ( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + ) +): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def as_dict(self): + return dict(self._asdict()) + + @staticmethod + def from_dict(d): + return UserPresenceState(**d) + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. 
+ """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 2b0ab2dcbf..4d65d4aeea 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -37,8 +37,8 @@ from sortedcontainers import SortedDict from twisted.internet import defer +from synapse.api.presence import UserPresenceState from synapse.metrics import LaterGauge -from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure from .units import Edu diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 94cc63001e..e53b6ac456 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -22,6 +22,7 @@ from twisted.internet import defer import synapse import synapse.metrics +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -39,7 +40,6 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8cbc23d901..c09ffcaf4c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -24,12 +24,12 @@ from synapse.api.errors import ( HttpResponseException, RequestSendFailed, ) +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5387b3724f..24e1940ee5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -33,13 +33,13 @@ from typing_extensions import ContextManager import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError +from synapse.api.presence import UserPresenceState from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.presence import UserPresenceState from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9f691e5792..4e3ec02d14 100644 --- 
a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,8 @@ from typing import List, Tuple +from synapse.api.presence import UserPresenceState from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index aaf225894e..8ccfb8fc46 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,15 +39,17 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import Optional +from typing import Dict, Iterable, List, Optional, Tuple from twisted.internet import defer +from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -68,8 +70,12 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): + direction: str, + column_names: Tuple[str, str], + from_token: Optional[Tuple[int, int]], + to_token: Optional[Tuple[int, int]], + engine: BaseDatabaseEngine, +) -> str: """Creates an SQL expression to bound the columns by the pagination tokens. @@ -90,21 +96,19 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". + direction: Whether we're paginating backwards("b") or forwards ("f"). + column_names: The column names to bound. Must *not* be user defined as + these get inserted directly into the SQL statement without escapes. + from_token: The start point for the pagination. This is an exclusive + minimum bound if direction is "f", and an inclusive maximum bound if + direction is "b". + to_token: The endpoint point for the pagination. This is an inclusive + maximum bound if direction is "f", and an exclusive minimum bound if + direction is "b". 
engine: The database engine to generate the clauses for Returns: - str: The sql expression + The sql expression """ assert direction in ("b", "f") @@ -132,7 +136,12 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) -def _make_generic_sql_bound(bound, column_names, values, engine): +def _make_generic_sql_bound( + bound: str, + column_names: Tuple[str, str], + values: Tuple[Optional[int], int], + engine: BaseDatabaseEngine, +) -> str: """Create an SQL expression that bounds the given column names by the values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. @@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine): out manually. Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", + bound: The comparison operator to use. One of ">", "<", ">=", "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. Must *not* be user defined + names: The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int|None, int]): The values to bound the columns by. If + values: The values to bound the columns by. If the first value is None then only creates a bound on the second column. engine: The database engine to generate the SQL for Returns: - str + The SQL statement """ assert bound in (">", "<", ">=", "<=") @@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): ) -def filter_to_clause(event_filter): +def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_room_min_stream_ordering(self): raise NotImplementedError() - @defer.inlineCallbacks - def get_room_events_stream_for_rooms( - self, room_ids, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_rooms( + self, + room_ids: Iterable[str], + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Dict[str, Tuple[List[EventBase], str]]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_ids + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[dict[str,tuple[list[FrozenEvent], str]]] - A map from room id to a tuple containing: - - list of recent events in the room - - stream ordering key for the start of the chunk of events returned. + A map from room id to a tuple containing: + - list of recent events in the room + - stream ordering key for the start of the chunk of events returned. 
""" from_id = RoomStreamToken.parse_stream_token(from_key).stream - room_ids = yield self._events_stream_cache.get_entities_changed( - room_ids, from_id - ) + room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) if not room_ids: return {} @@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): - res = yield make_deferred_yieldable( + res = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if self._events_stream_cache.has_entity_changed(room_id, from_key) } - @defer.inlineCallbacks - def get_room_events_stream_for_room( - self, room_id, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_room( + self, + room_id: str, + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Tuple[List[EventBase], str]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_id + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns the list of - events (in ascending order) and the token from the start of - the chunk of events returned. + The list of events (in ascending order) and the token from the start + of the chunk of events returned. 
""" if from_key == to_key: return [], from_key @@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) + has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) if not has_changed: return [], from_key @@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): + async def get_recent_events_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[EventBase], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of events and a token pointing to the start of the returned + events. The events returned are in ascending order. """ - rows, token = yield self.get_recent_event_ids_for_room( + rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): + async def get_recent_event_ids_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[_EventDictReturn], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. - The events returned are in ascending order. 
+ A list of _EventDictReturn and a token pointing to the start of the + returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. if limit == 0: @@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): """Gets details of the first event in a room at or before a stream ordering Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: Deferred[(int, int, str)]: @@ -574,55 +582,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return "t%d-%d" % (topo, token) - def get_stream_token_for_event(self, event_id): + async def get_stream_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "s%d" stream token. + A "s%d" stream token. """ - return self.db_pool.simple_select_one_onecol( + row = await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) + ) + return "s%d" % (row,) - def get_topological_token_for_event(self, event_id): + async def get_topological_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "t%d-%d" topological token. + A "t%d-%d" topological token. """ - return self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) + return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - def get_max_topological_token(self, room_id, stream_key): + async def get_max_topological_token(self, room_id: str, stream_key: int) -> int: """Get the max topological token in a room before the given stream ordering. Args: - room_id (str) - stream_key (int) + room_id + stream_key Returns: - Deferred[int] + The maximum topological token. """ sql = ( "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" 
) - return self.db_pool.execute( + row = await self.db_pool.execute( "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) + ) + return row[0][0] if row else 0 def _get_max_topological_txn(self, txn, room_id): txn.execute( @@ -634,16 +643,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows[0][0] if rows else 0 @staticmethod - def _set_before_and_after(events, rows, topo_order=True): + def _set_before_and_after( + events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True + ): """Inserts ordering information to events' internal metadata from the DB rows. Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. + events + rows + topo_order: Whether the events were ordered topologically or by stream + ordering. If true then all rows should have a non null + topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering @@ -656,25 +667,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): + async def get_events_around( + self, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter] = None, + ) -> dict: """Retrieve events and pagination tokens around a given event in a room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict """ - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -684,11 +689,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self.get_events_as_list( + events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) - events_after = yield self.get_events_as_list( + events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) @@ -700,17 +705,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): + self, + txn, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter], + ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) + room_id + event_id + before_limit + after_limit + event_filter Returns: dict @@ -758,22 +769,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): + async def get_all_new_events_stream( + self, from_id: int, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. 
Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). """ def get_all_new_events_stream_txn(txn): @@ -795,11 +807,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events @@ -817,21 +829,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_federation_out_pos", ) - async def update_federation_out_pos(self, typ, stream_id): + async def update_federation_out_pos(self, typ: str, stream_id: int) -> None: if self._need_to_reset_federation_stream_positions: await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn ) self._need_to_reset_federation_stream_positions = False - return await self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="federation_stream_position", keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) - def _reset_federation_positions_txn(self, txn): + def _reset_federation_positions_txn(self, txn) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. """ @@ -892,39 +904,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): values={"stream_id": stream_id}, ) - def has_room_changed_since(self, room_id, stream_id): + def has_room_changed_since(self, room_id: str, stream_id: int) -> bool: return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): + room_id: str, + from_token: RoomStreamToken, + to_token: Optional[RoomStreamToken] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[_EventDictReturn], str]: """Returns list of events before or after a given token. Args: txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. 
- event_filter (Filter|None): If provided filters the events to + room_id + from_token: The token used to stream from + to_token: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. + A list of _EventDictReturn and a token that points to the end of the + result set. If no events are returned then the end of the stream has + been reached (i.e. there are no events between `from_token` and + `to_token`), or `limit` is zero. """ assert int(limit) >= 0 @@ -1008,35 +1018,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, str(next_token) - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): + async def paginate_room_events( + self, + room_id: str, + from_key: str, + to_key: Optional[str] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[EventBase], str]: """Returns list of events before or after a given token. Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. + room_id + from_key: The token used to stream from + to_key: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). + The results as a list of events and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between `from_key` + and `to_key`). 
""" from_key = RoomStreamToken.parse(from_key) if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, @@ -1047,7 +1060,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -1057,8 +1070,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: return self._stream_id_gen.get_current_token() - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py deleted file mode 100644 index 18a462f0ee..0000000000 --- a/synapse/storage/presence.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple - -from synapse.api.constants import PresenceState - - -class UserPresenceState( - namedtuple( - "UserPresenceState", - ( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - ) -): - """Represents the current presence state of the user. - - user_id (str) - last_active (int): Time in msec that the user last interacted with server. - last_federation_update (int): Time in msec since either a) we sent a presence - update to other servers or b) we received a presence update, depending - on if is a local user or not. - last_user_sync (int): Time in msec that the user last *completed* a sync - (or event stream). - status_msg (str): User set status message. - """ - - def as_dict(self): - return dict(self._asdict()) - - @staticmethod - def from_dict(d): - return UserPresenceState(**d) - - def copy_and_replace(self, **kwargs): - return self._replace(**kwargs) - - @classmethod - def default(cls, user_id): - """Returns a default presence state. 
- """ - return cls( - user_id=user_id, - state=PresenceState.OFFLINE, - last_active_ts=0, - last_federation_update_ts=0, - last_user_sync_ts=0, - status_msg=None, - currently_active=False, - ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 05ea40a7de..306dcfe944 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,6 +19,7 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( @@ -32,7 +33,6 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest.client.v1 import room -from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from tests import unittest diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index a6012c973d..918387733b 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import NotFoundError from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = store.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) - - # Purge everything before this topological token - purge = defer.ensureDeferred( - storage.purge_events.purge_history(self.room_id, event, True) + event = self.get_success( + store.get_topological_token_for_event(last["event_id"]) ) - self.pump() - self.assertEqual(self.successResultOf(purge), None) - # Try and get the events - get_first = store.get_event(first["event_id"]) - get_second = store.get_event(second["event_id"]) - get_third = store.get_event(third["event_id"]) - get_last = store.get_event(last["event_id"]) - self.pump() + # Purge everything before this topological token + self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. 
- self.failureResultOf(get_first) - self.failureResultOf(get_second) - self.failureResultOf(get_third) - self.successResultOf(get_last) + self.get_failure(store.get_event(first["event_id"]), NotFoundError) + self.get_failure(store.get_event(second["event_id"]), NotFoundError) + self.get_failure(store.get_event(third["event_id"]), NotFoundError) + self.get_success(store.get_event(last["event_id"])) def test_purge_wont_delete_extrems(self): """ @@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = storage.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) + event = self.get_success( + storage.get_topological_token_for_event(last["event_id"]) + ) event = "t{}-{}".format( *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) ) @@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase): self.assertIn("greater than forward", f.value.args[0]) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) - self.pump() - - # Nothing is deleted. - self.successResultOf(get_first) - self.successResultOf(get_second) - self.successResultOf(get_third) - self.successResultOf(get_last) + self.get_success(storage.get_event(first["event_id"])) + self.get_success(storage.get_event(second["event_id"])) + self.get_success(storage.get_event(third["event_id"])) + self.get_success(storage.get_event(last["event_id"])) -- cgit 1.5.1 From e04e465b4d2c66acb8885c31736c7b7bb4e7be52 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 17 Aug 2020 17:05:00 +0100 Subject: Use the default templates when a custom template file cannot be found (#8037) Fixes https://github.com/matrix-org/synapse/issues/6583 --- changelog.d/8037.feature | 1 + docs/sample_config.yaml | 4 +- synapse/config/_base.py | 100 ++++++++++++++++++++- synapse/config/emailconfig.py | 145 ++++++++++++++----------------- synapse/config/saml2_config.py | 14 +-- synapse/config/sso.py | 37 ++++---- synapse/handlers/account_validity.py | 20 +---- synapse/handlers/auth.py | 12 ++- synapse/handlers/oidc_handler.py | 5 +- synapse/push/mailer.py | 72 +-------------- synapse/push/pusher.py | 31 ++----- synapse/python_dependencies.py | 2 - synapse/rest/client/v2_alpha/account.py | 44 +++------- synapse/rest/client/v2_alpha/register.py | 31 ++----- tests/config/test_base.py | 82 +++++++++++++++++ 15 files changed, 310 insertions(+), 290 deletions(-) create mode 100644 changelog.d/8037.feature create mode 100644 tests/config/test_base.py (limited to 'tests') diff --git a/changelog.d/8037.feature b/changelog.d/8037.feature new file mode 100644 index 0000000000..2e5127477d --- /dev/null +++ b/changelog.d/8037.feature @@ -0,0 +1 @@ +Use the default template file when its equivalent is not found in a custom template directory. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9235b89fb1..f168853f67 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2002,9 +2002,7 @@ email: # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. 
- # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index fd137853b1..1417487427 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,12 +18,16 @@ import argparse import errno import os +import time +import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, List, MutableMapping, Optional +from typing import Any, Callable, List, MutableMapping, Optional import attr +import jinja2 +import pkg_resources import yaml @@ -100,6 +104,11 @@ class Config(object): def __init__(self, root_config=None): self.root = root_config + # Get the path to the default Synapse template directory + self.default_template_dir = pkg_resources.resource_filename( + "synapse", "res/templates" + ) + def __getattr__(self, item: str) -> Any: """ Try and fetch a configuration option that does not exist on this class. @@ -184,6 +193,95 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() + def read_templates( + self, filenames: List[str], custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Load a list of template files from disk using the given variables. + + This function will attempt to load the given templates from the default Synapse + template directory. If `custom_template_directory` is supplied, that directory + is tried first. + + Files read are treated as Jinja templates. These templates are not rendered yet. + + Args: + filenames: A list of template filenames to read. + + custom_template_directory: A directory to try to look for the templates + before using the default Synapse template directory instead. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A list of jinja2 templates. + """ + templates = [] + search_directories = [self.default_template_dir] + + # The loader will first look in the custom template directory (if specified) for the + # given filename. 
If it doesn't find it, it will use the default template dir instead + if custom_template_directory: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.insert(0, custom_template_directory) + + loader = jinja2.FileSystemLoader(search_directories) + env = jinja2.Environment(loader=loader, autoescape=True) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) + + for filename in filenames: + # Load the template + template = env.get_template(filename) + templates.append(template) + + return templates + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + class RootConfig(object): """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index a63acbdc63..7a796996c0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -23,7 +23,6 @@ from enum import Enum from typing import Optional import attr -import pkg_resources from ._base import Config, ConfigError @@ -98,21 +97,18 @@ class EmailConfig(Config): if parsed[1] == "": raise RuntimeError("Invalid notif_from address") + # A user-configurable template directory template_dir = email_config.get("template_dir") - # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxiliary templates like mail.css we will need). - # (Note that loading as package_resources with jinja.PackageLoader doesn't - # work for the same reason.) - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - self.email_template_dir = os.path.abspath(template_dir) + if isinstance(template_dir, str): + # We need an absolute path, because we change directory after starting (and + # we don't yet know what auxiliary templates like mail.css we will need). 
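The fallback behaviour that read_templates builds above comes straight from jinja2: FileSystemLoader accepts a list of directories and get_template returns the first match, so a custom directory shadows the defaults and anything missing falls through. A self-contained sketch of that mechanism, with invented template names and temporary directories standing in for the real template paths:

import os
import tempfile

import jinja2

custom_dir = tempfile.mkdtemp()
default_dir = tempfile.mkdtemp()

# The "default" directory has both templates; the custom one overrides just one.
with open(os.path.join(default_dir, "greeting.html"), "w") as f:
    f.write("default greeting: {{ name }}")
with open(os.path.join(default_dir, "only_default.html"), "w") as f:
    f.write("served from the default directory")
with open(os.path.join(custom_dir, "greeting.html"), "w") as f:
    f.write("custom greeting: {{ name }}")

# Directories are searched in order, mirroring search_directories above.
env = jinja2.Environment(
    loader=jinja2.FileSystemLoader([custom_dir, default_dir]), autoescape=True
)

print(env.get_template("greeting.html").render(name="world"))  # custom greeting: world
print(env.get_template("only_default.html").render())  # served from the default directory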
+ template_dir = os.path.abspath(template_dir) + elif template_dir is not None: + # If template_dir is something other than a str or None, warn the user + raise ConfigError("Config option email.template_dir must be type str") self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_config = config.get("account_validity") or {} - account_validity_renewal_enabled = account_validity_config.get("renew_at") - self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email # is not defined @@ -166,19 +162,6 @@ class EmailConfig(Config): email_config.get("validation_token_lifetime", "1h") ) - if ( - self.email_enable_notifs - or account_validity_renewal_enabled - or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): - # make sure we can import the required deps - import bleach - import jinja2 - - # prevent unused warnings - jinja2 - bleach - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: missing = [] if not self.email_notif_from: @@ -196,49 +179,49 @@ class EmailConfig(Config): # These email templates have placeholders in them, and thus must be # parsed using a templating engine during a request - self.email_password_reset_template_html = email_config.get( + password_reset_template_html = email_config.get( "password_reset_template_html", "password_reset.html" ) - self.email_password_reset_template_text = email_config.get( + password_reset_template_text = email_config.get( "password_reset_template_text", "password_reset.txt" ) - self.email_registration_template_html = email_config.get( + registration_template_html = email_config.get( "registration_template_html", "registration.html" ) - self.email_registration_template_text = email_config.get( + registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) - self.email_add_threepid_template_html = email_config.get( + add_threepid_template_html = email_config.get( "add_threepid_template_html", "add_threepid.html" ) - self.email_add_threepid_template_text = email_config.get( + add_threepid_template_text = email_config.get( "add_threepid_template_text", "add_threepid.txt" ) - self.email_password_reset_template_failure_html = email_config.get( + password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) - self.email_registration_template_failure_html = email_config.get( + registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) - self.email_add_threepid_template_failure_html = email_config.get( + add_threepid_template_failure_html = email_config.get( "add_threepid_template_failure_html", "add_threepid_failure.html" ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup - email_password_reset_template_success_html = email_config.get( + password_reset_template_success_html = email_config.get( "password_reset_template_success_html", "password_reset_success.html" ) - email_registration_template_success_html = email_config.get( + registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) - email_add_threepid_template_success_html = email_config.get( + add_threepid_template_success_html = email_config.get( "add_threepid_template_success_html", "add_threepid_success.html" ) - # Check templates exist - for f in [ + # Read all templates from disk + ( 
self.email_password_reset_template_html, self.email_password_reset_template_text, self.email_registration_template_html, @@ -248,32 +231,36 @@ class EmailConfig(Config): self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, - email_password_reset_template_success_html, - email_registration_template_success_html, - email_add_threepid_template_success_html, - ]: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p,)) - - # Retrieve content of web templates - filepath = os.path.join( - self.email_template_dir, email_password_reset_template_success_html + password_reset_template_success_html_template, + registration_template_success_html_template, + add_threepid_template_success_html_template, + ) = self.read_templates( + [ + password_reset_template_html, + password_reset_template_text, + registration_template_html, + registration_template_text, + add_threepid_template_html, + add_threepid_template_text, + password_reset_template_failure_html, + registration_template_failure_html, + add_threepid_template_failure_html, + password_reset_template_success_html, + registration_template_success_html, + add_threepid_template_success_html, + ], + template_dir, ) - self.email_password_reset_template_success_html = self.read_file( - filepath, "email.password_reset_template_success_html" - ) - filepath = os.path.join( - self.email_template_dir, email_registration_template_success_html - ) - self.email_registration_template_success_html_content = self.read_file( - filepath, "email.registration_template_success_html" + + # Render templates that do not contain any placeholders + self.email_password_reset_template_success_html_content = ( + password_reset_template_success_html_template.render() ) - filepath = os.path.join( - self.email_template_dir, email_add_threepid_template_success_html + self.email_registration_template_success_html_content = ( + registration_template_success_html_template.render() ) - self.email_add_threepid_template_success_html_content = self.read_file( - filepath, "email.add_threepid_template_success_html" + self.email_add_threepid_template_success_html_content = ( + add_threepid_template_success_html_template.render() ) if self.email_enable_notifs: @@ -290,17 +277,19 @@ class EmailConfig(Config): % (", ".join(missing),) ) - self.email_notif_template_html = email_config.get( + notif_template_html = email_config.get( "notif_template_html", "notif_mail.html" ) - self.email_notif_template_text = email_config.get( + notif_template_text = email_config.get( "notif_template_text", "notif_mail.txt" ) - for f in self.email_notif_template_text, self.email_notif_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.email_notif_template_html, + self.email_notif_template_text, + ) = self.read_templates( + [notif_template_html, notif_template_text], template_dir, + ) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True @@ -309,18 +298,20 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if account_validity_renewal_enabled: - self.email_expiry_template_html = email_config.get( + if self.account_validity.renew_by_email_enabled: + expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) - self.email_expiry_template_text = 
email_config.get( + expiry_template_text = email_config.get( "expiry_template_text", "notice_expiry.txt" ) - for f in self.email_expiry_template_text, self.email_expiry_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.account_validity_template_html, + self.account_validity_template_text, + ) = self.read_templates( + [expiry_template_html, expiry_template_text], template_dir, + ) subjects_config = email_config.get("subjects", {}) subjects = {} @@ -400,9 +391,7 @@ class EmailConfig(Config): # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 9277b5f342..036f8c0e90 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -18,8 +18,6 @@ import logging from typing import Any, List import attr -import jinja2 -import pkg_resources from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -171,15 +169,9 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) - template_dir = saml2_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - - loader = jinja2.FileSystemLoader(template_dir) - # enable auto-escape here, to having to remember to escape manually in the - # template - env = jinja2.Environment(loader=loader, autoescape=True) - self.saml2_error_html_template = env.get_template("saml_error.html") + self.saml2_error_html_template = self.read_templates( + ["saml_error.html"], saml2_config.get("template_dir") + ) def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 73b7296399..4427676167 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
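The saml2 change above and the sso change below follow the same split the email config now uses: read_templates hands back jinja2.Template objects, templates with placeholders are kept and rendered per request, and placeholder-free ones are rendered once at startup. A tiny illustration of why the second group can be pre-rendered (the HTML strings are made up, not the real template contents):

import jinja2

# A template with no placeholders renders to the same HTML every time,
# so rendering it once at config load is safe.
account_deactivated = jinja2.Template("<p>This account has been deactivated.</p>")
static_html = account_deactivated.render()

# A template with placeholders stays a Template object and is rendered
# per request with the real values.
error_page = jinja2.Template("<p>{{ error_description }}</p>")
html = error_page.render(error_description="example reason")

print(static_html)
print(html)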
-import os from typing import Any, Dict -import pkg_resources - from ._base import Config @@ -29,22 +26,32 @@ class SSOConfig(Config): def read_config(self, config, **kwargs): sso_config = config.get("sso") or {} # type: Dict[str, Any] - # Pick a template directory in order of: - # * The sso-specific template_dir - # * /path/to/synapse/install/res/templates + # The sso-specific template_dir template_dir = sso_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_template_dir = template_dir - self.sso_account_deactivated_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), - "sso_account_deactivated_template", + # Read templates from disk + ( + self.sso_redirect_confirm_template, + self.sso_auth_confirm_template, + self.sso_error_template, + sso_account_deactivated_template, + sso_auth_success_template, + ) = self.read_templates( + [ + "sso_redirect_confirm.html", + "sso_auth_confirm.html", + "sso_error.html", + "sso_account_deactivated.html", + "sso_auth_success.html", + ], + template_dir, ) - self.sso_auth_success_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_auth_success.html"), - "sso_auth_success_template", + + # These templates have no placeholders, so render them here + self.sso_account_deactivated_template = ( + sso_account_deactivated_template.render() ) + self.sso_auth_success_template = sso_auth_success_template.render() self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 590135d19c..b865bf5b48 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils -try: - from synapse.push.mailer import load_jinja2_templates -except ImportError: - load_jinja2_templates = None - logger = logging.getLogger(__name__) @@ -47,9 +42,11 @@ class AccountValidityHandler(object): if ( self._account_validity.enabled and self._account_validity.renew_by_email_enabled - and load_jinja2_templates ): # Don't do email-specific configuration if renewal by email is disabled. + self._template_html = self.config.account_validity_template_html + self._template_text = self.config.account_validity_template_text + try: app_name = self.hs.config.email_app_name @@ -65,17 +62,6 @@ class AccountValidityHandler(object): self._raw_from = email.utils.parseaddr(self._from_string)[1] - self._template_html, self._template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_expiry_template_html, - self.config.email_expiry_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) - # Check the renewal emails to send and send them every 30min. 
def send_emails(): # run as a background process to make sure that the database transactions diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c24e7bafe0..68d6870e40 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email @@ -132,18 +131,17 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_redirect_confirm.html"], - )[0] + self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_auth_confirm.html"], - )[0] + self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if # the account is deactivated. 
self._sso_account_deactivated_template = ( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index fa5ee5de8f..87d28a7ae9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -38,7 +38,6 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.push.mailer import load_jinja2_templates from synapse.types import UserID, map_username_to_mxid_localpart if TYPE_CHECKING: @@ -123,9 +122,7 @@ class OidcHandler: self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_error.html"] - )[0] + self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index af117fddf9..c38e037281 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -16,8 +16,7 @@ import email.mime.multipart import email.utils import logging -import time -import urllib +import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar @@ -640,72 +639,3 @@ def string_ordinal_total(s): for c in s: tot += ord(c) return tot - - -def format_ts_filter(value, format): - return time.strftime(format, time.localtime(value / 1000)) - - -def load_jinja2_templates( - template_dir, - template_filenames, - apply_format_ts_filter=False, - apply_mxc_to_http_filter=False, - public_baseurl=None, -): - """Loads and returns one or more jinja2 templates and applies optional filters - - Args: - template_dir (str): The directory where templates are stored - template_filenames (list[str]): A list of template filenames - apply_format_ts_filter (bool): Whether to apply a template filter that formats - timestamps - apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts - mxc urls to http urls - public_baseurl (str|None): The public baseurl of the server. 
Required for - apply_mxc_to_http_filter to be enabled - - Returns: - A list of jinja2 templates corresponding to the given list of filenames, - with order preserved - """ - logger.info( - "loading email templates %s from '%s'", template_filenames, template_dir - ) - loader = jinja2.FileSystemLoader(template_dir) - env = jinja2.Environment(loader=loader) - - if apply_format_ts_filter: - env.filters["format_ts"] = format_ts_filter - - if apply_mxc_to_http_filter and public_baseurl: - env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl) - - templates = [] - for template_filename in template_filenames: - template = env.get_template(template_filename) - templates.append(template) - - return templates - - -def _create_mxc_to_http_filter(public_baseurl): - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if "#" in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - serverAndMediaId, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8ad0bf5936..f626797133 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -15,22 +15,13 @@ import logging +from synapse.push.emailpusher import EmailPusher +from synapse.push.mailer import Mailer + from .httppusher import HttpPusher logger = logging.getLogger(__name__) -# We try importing this if we can (it will fail if we don't -# have the optional email dependencies installed). We don't -# yet have the config to know if we need the email pusher, -# but importing this after daemonizing seems to fail -# (even though a simple test of importing from a daemonized -# process works fine) -try: - from synapse.push.emailpusher import EmailPusher - from synapse.push.mailer import Mailer, load_jinja2_templates -except Exception: - pass - class PusherFactory(object): def __init__(self, hs): @@ -43,16 +34,8 @@ class PusherFactory(object): if hs.config.email_enable_notifs: self.mailers = {} # app_name -> Mailer - self.notif_template_html, self.notif_template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_notif_template_html, - self.config.email_notif_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) + self._notif_template_html = hs.config.email_notif_template_html + self._notif_template_text = hs.config.email_notif_template_text self.pusher_types["email"] = self._create_email_pusher @@ -73,8 +56,8 @@ class PusherFactory(object): mailer = Mailer( hs=self.hs, app_name=app_name, - template_html=self.notif_template_html, - template_text=self.notif_template_text, + template_html=self._notif_template_html, + template_text=self._notif_template_text, ) self.mailers[app_name] = mailer return EmailPusher(self.hs, pusherdict, mailer) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e5f22fb858..3250d41dde 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], # we use execute_batch, which arrived in psycopg 2.7. 
"postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 - "resources.consent": ["Jinja2>=2.9"], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index fead85074b..203e76b9f2 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -32,7 +32,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import canonicalise_email, check_3pid_allowed @@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_password_reset_template_html, - self.config.email_password_reset_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, ) async def on_POST(self, request): @@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], + self._failure_email_template = ( + self.config.email_password_reset_template_failure_html ) async def on_GET(self, request, medium): @@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_password_reset_template_success_html + html = self.config.email_password_reset_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) @@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_add_threepid_template_html, - self.config.email_add_threepid_template_text, - ], - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, ) async def on_POST(self, request): @@ -578,9 +559,8 @@ class 
AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_add_threepid_template_failure_html], + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html ) async def on_GET(self, request): @@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f808175698..7290fd0756 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -44,7 +44,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - from synapse.push.mailer import Mailer, load_jinja2_templates - - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_registration_template_html, - self.config.email_registration_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, ) async def on_POST(self, request): @@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], + self._failure_email_template = ( + self.config.email_registration_template_failure_html ) async def on_GET(self, request, medium): @@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/tests/config/test_base.py b/tests/config/test_base.py new file mode 100644 index 0000000000..42ee5f56d9 --- /dev/null +++ b/tests/config/test_base.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import tempfile + +from synapse.config import ConfigError +from synapse.util.stringutils import random_string + +from tests import unittest + + +class BaseConfigTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.hs = hs + + def test_loading_missing_templates(self): + # Use a temporary directory that exists on the system, but that isn't likely to + # contain template files + with tempfile.TemporaryDirectory() as tmp_dir: + # Attempt to load an HTML template from our custom template directory + template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + + # If no errors, we should've gotten the default template instead + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"error_description": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_custom_templates(self): + # Use a temporary directory that exists on the system + with tempfile.TemporaryDirectory() as tmp_dir: + # Create a temporary bogus template file + with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template: + # Get temporary file's filename + template_filename = os.path.basename(tmp_template.name) + + # Write a custom HTML template + contents = b"{{ test_variable }}" + tmp_template.write(contents) + tmp_template.flush() + + # Attempt to load the template from our custom template directory + template = ( + self.hs.config.read_templates([template_filename], tmp_dir) + )[0] + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"test_variable": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_template_from_nonexistent_custom_directory(self): + with self.assertRaises(ConfigError): + self.hs.config.read_templates( + ["some_filename.html"], "a_nonexistent_directory" + ) -- cgit 1.5.1 From 050e20e7ca56c3a5985fdcf64012800c153260f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 12:18:01 -0400 Subject: Convert some of the general database methods to async (#8100) --- changelog.d/8100.misc | 1 + synapse/storage/database.py | 23 ++++++++----------- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/events_worker.py | 16 +++++++------ synapse/storage/databases/main/registration.py | 8 +++---- synapse/storage/databases/main/roommember.py | 4 ++-- tests/handlers/test_profile.py | 4 ++-- tests/handlers/test_typing.py | 2 +- tests/storage/test_appservice.py | 16 +++++++++---- tests/storage/test_base.py | 16 ++++++++----- tests/storage/test_event_push_actions.py | 30 +++++++++++++------------ tests/storage/test_main.py | 2 +- tests/storage/test_profile.py | 4 ++-- 13 files changed, 69 insertions(+), 59 
deletions(-) create mode 100644 changelog.d/8100.misc (limited to 'tests') diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8100.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4ada6f5563..8a9e06efcf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -332,8 +332,7 @@ class DatabasePool(object): """ return self._db_pool.running - @defer.inlineCallbacks - def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self): """ Is it safe to use native UPSERT? @@ -342,7 +341,7 @@ class DatabasePool(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self.simple_select_list( + updates = await self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -614,8 +613,7 @@ class DatabasePool(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -631,7 +629,7 @@ class DatabasePool(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) + await self.runInteraction(desc, self.simple_insert_txn, table, values) except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -684,8 +682,7 @@ class DatabasePool(object): txn.executemany(sql, vals) - @defer.inlineCallbacks - def simple_upsert( + async def simple_upsert( self, table, keyvalues, @@ -714,14 +711,14 @@ class DatabasePool(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(None or bool): Native upserts always return None. Emulated + None or bool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ attempts = 0 while True: try: - result = yield self.runInteraction( + return await self.runInteraction( desc, self.simple_upsert_txn, table, @@ -730,7 +727,6 @@ class DatabasePool(object): insertion_values, lock=lock, ) - return result except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -1121,8 +1117,7 @@ class DatabasePool(object): return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def simple_select_many_batch( + async def simple_select_many_batch( self, table, column, @@ -1156,7 +1151,7 @@ class DatabasePool(object): it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: - rows = yield self.runInteraction( + rows = await self.runInteraction( desc, self.simple_select_many_txn, table, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5cf1a88399..02568a2391 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore( service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. 
Returns: - A Deferred which resolves when the state was set successfully. + An Awaitable which resolves when the state was set successfully. """ return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5687448e3d..8c63a0dc4d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", + rows = yield defer.ensureDeferred( + self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) ) return {r["event_id"] for r in rows} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de50fa6e94..068ad22b30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,7 @@ import logging import re -from typing import Dict, List, Optional - -from twisted.internet.defer import Deferred +from typing import Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore): id_server (str) Returns: - Deferred + Awaitable """ # We need to use an upsert, in case they user had already bound the # threepid @@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> Awaitable: """Record a mapping from an external user id to a mxid Args: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1cc8c08ed0..161edbeccb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. 
""" - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d70e1fc608..b609b30d4a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() self.hs = hs @@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") + yield defer.ensureDeferred(self.store.create_profile("caroline")) yield self.store.set_profile_displayname("caroline", "Caroline") response = yield defer.ensureDeferred( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64afd581bc..e01de158e5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 98b74890d5..a425e66f37 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" 
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index efcaeef1e7..13bcac743a 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", values={"columname": "Value"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", values={"columname": "Value"} + ) ) self.mock_txn.execute.assert_called_with( @@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 857db071d4..238bad5b45 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db_pool.simple_insert( - "events", - { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }, + return defer.ensureDeferred( + self.store.db_pool.simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) ) # start with the base case where there are no events in the table diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index ab0df5ea93..fbf8af940a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") - yield self.store.create_profile(self.user.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield self.store.set_profile_displayname(self.user.localpart, self.displayname) users, total = yield self.store.get_users_paginate( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9b6f7211ae..9d5b8aa47d 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") @@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) + yield 
defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" -- cgit 1.5.1 From 2f4d60a5ba9ec60ab4f3384cbef20fe662b4349b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 08:49:59 -0400 Subject: Iteratively encode JSON responses to avoid blocking the reactor. (#8013) --- changelog.d/8013.feature | 1 + synapse/http/server.py | 97 +++++++++++++++++++++++++++--- synapse/python_dependencies.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 6 +- tests/test_server.py | 1 - 5 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 changelog.d/8013.feature (limited to 'tests') diff --git a/changelog.d/8013.feature b/changelog.d/8013.feature new file mode 100644 index 0000000000..b1eaf1e78a --- /dev/null +++ b/changelog.d/8013.feature @@ -0,0 +1 @@ +Iteratively encode JSON to avoid blocking the reactor. diff --git a/synapse/http/server.py b/synapse/http/server.py index ffe6cfa09e..37fdf14405 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,12 +22,13 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import encode_canonical_json, encode_pretty_printed_json +from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from zope.interface import implementer -from twisted.internet import defer +from twisted.internet import defer, interfaces from twisted.python import failure from twisted.web import resource from twisted.web.server import NOT_DONE_YET, Request @@ -499,6 +500,78 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect): pass +@implementer(interfaces.IPullProducer) +class _ByteProducer: + """ + Iteratively write bytes to the request. + """ + + # The minimum number of bytes for each chunk. Note that the last chunk will + # usually be smaller than this. + min_chunk_size = 1024 + + def __init__( + self, request: Request, iterator: Iterator[bytes], + ): + self._request = request + self._iterator = iterator + + def start(self) -> None: + self._request.registerProducer(self, False) + + def _send_data(self, data: List[bytes]) -> None: + """ + Send a list of strings as a response to the request. + """ + if not data: + return + self._request.write(b"".join(data)) + + def resumeProducing(self) -> None: + # We've stopped producing in the meantime (note that this might be + # re-entrant after calling write). + if not self._request: + return + + # Get the next chunk and write it to the request. + # + # The output of the JSON encoder is coalesced until min_chunk_size is + # reached. (This is because JSON encoders produce a very small output + # per iteration.) + # + # Note that buffer stores a list of bytes (instead of appending to + # bytes) to hopefully avoid many allocations. + buffer = [] + buffered_bytes = 0 + while buffered_bytes < self.min_chunk_size: + try: + data = next(self._iterator) + buffer.append(data) + buffered_bytes += len(data) + except StopIteration: + # The entire JSON object has been serialized, write any + # remaining data, finalize the producer and the request, and + # clean-up any references. 
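The producer above exists because JSONEncoder.iterencode yields a long stream of very small fragments, so resumeProducing coalesces them up to min_chunk_size before each write rather than flooding the transport with tiny writes. A standalone sketch of that behaviour using only the standard library (the sample objects are invented; the 1024-byte threshold matches min_chunk_size above):

import json

encoder = json.JSONEncoder()

# iterencode yields many small string fragments rather than one big string.
print(list(encoder.iterencode({"a": [1, 2], "b": "x"})))

# Coalescing those fragments into ~1 KiB writes, roughly as resumeProducing does:
buffer = []
buffered_bytes = 0
for piece in encoder.iterencode({"numbers": list(range(500))}):
    data = piece.encode("utf-8")
    buffer.append(data)
    buffered_bytes += len(data)
    if buffered_bytes >= 1024:
        chunk = b"".join(buffer)  # request.write(chunk) would happen here
        buffer, buffered_bytes = [], 0
# Any remainder is then written, after which the producer unregisters itself
# and finishes the request.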
+ self._send_data(buffer) + self._request.unregisterProducer() + self._request.finish() + self.stopProducing() + return + + self._send_data(buffer) + + def stopProducing(self) -> None: + self._request = None + + +def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: + """ + Encode an object into JSON. Returns an iterator of bytes. + """ + for chunk in json_encoder.iterencode(json_object): + yield chunk.encode("utf-8") + + def respond_with_json( request: Request, code: int, @@ -533,15 +606,23 @@ def respond_with_json( return None if pretty_print: - json_bytes = encode_pretty_printed_json(json_object) + b"\n" + encoder = iterencode_pretty_printed_json else: if canonical_json or synapse.events.USE_FROZEN_DICTS: - # canonicaljson already encodes to bytes - json_bytes = encode_canonical_json(json_object) + encoder = iterencode_canonical_json else: - json_bytes = json_encoder.encode(json_object).encode("utf-8") + encoder = _encode_json_bytes + + request.setResponseCode(code) + request.setHeader(b"Content-Type", b"application/json") + request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") - return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors) + if send_cors: + set_cors_headers(request) + + producer = _ByteProducer(request, encoder(json_object)) + producer.start() + return NOT_DONE_YET def respond_with_json_bytes( diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 3250d41dde..dd77a44b8d 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -43,7 +43,7 @@ REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.2.0", + "canonicaljson>=1.3.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9b3f85b306..e266204f95 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,12 +15,12 @@ import logging from typing import Dict, Set -from canonicaljson import encode_canonical_json, json +from canonicaljson import json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes +from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request logger = logging.getLogger(__name__) @@ -223,4 +223,4 @@ class RemoteKey(DirectServeJsonResource): results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json(request, 200, results, canonical_json=True) diff --git a/tests/test_server.py b/tests/test_server.py index d628070e48..655c918a15 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -178,7 +178,6 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) - self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"]) class OptionsResourceTests(unittest.TestCase): -- cgit 1.5.1 From 5cf7c1299541d4b5ca5b3ac547a300a87465c7e5 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Aug 2020 14:14:27 +0100 Subject: Remove : from allowed client_secret chars (#8101) Closes: 
https://github.com/matrix-org/synapse/issues/6766 Equivalent Sydent PR: https://github.com/matrix-org/sydent/pull/309 I believe it's now time to remove the extra allowed `:` from `client_secret` parameters. --- CHANGES.md | 14 ++++++++++++++ changelog.d/8101.bugfix | 1 + synapse/util/stringutils.py | 4 +--- tests/util/test_stringutils.py | 3 --- 4 files changed, 16 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8101.bugfix (limited to 'tests') diff --git a/CHANGES.md b/CHANGES.md index d4cc179489..c1b8673c04 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,17 @@ +For the next release +==================== + +Removal warning +--------------- + +Some older clients used a +[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken) +(`:`) in the `client_secret` parameter of various endpoints. The incorrect +behaviour was allowed for backwards compatibility, but is now being removed +from Synapse as most users have updated their client. Further context can be +found at [\#6766](https://github.com/matrix-org/synapse/issues/6766). + + Synapse 1.19.0 (2020-08-17) =========================== diff --git a/changelog.d/8101.bugfix b/changelog.d/8101.bugfix new file mode 100644 index 0000000000..703bba4234 --- /dev/null +++ b/changelog.d/8101.bugfix @@ -0,0 +1 @@ +Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints. diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2e2b40a426..61d96a6c28 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -24,9 +24,7 @@ from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -# Note: The : character is allowed here for older clients, but will be removed in a -# future release. Context: https://github.com/matrix-org/synapse/issues/6766 -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$") +client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py index 4f4da29a98..8491f7cc83 100644 --- a/tests/util/test_stringutils.py +++ b/tests/util/test_stringutils.py @@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase): "_--something==_", "...--==-18913", "8Dj2odd-e9asd.cd==_--ddas-secret-", - # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766 - # To be removed in a future release - "SECRET:1234567890", ] bad = [ -- cgit 1.5.1 From f40645e60b9cab69c953094848be61c0989a91cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 16:20:49 -0400 Subject: Convert events worker database to async/await. 
(#8071) --- changelog.d/8071.misc | 1 + synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 6 +- synapse/handlers/room_member.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 30 +++-- synapse/storage/databases/main/events_worker.py | 132 ++++++++++++--------- synapse/storage/databases/main/stream.py | 1 - .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- 12 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 changelog.d/8071.misc (limited to 'tests') diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..5b270228e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. 
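Most of the storage and handler changes below have the same shape: a Deferred pipeline built with addCallback becomes two sequential awaits. A self-contained sketch of that transformation (the fetch functions are stand-ins, not real Synapse storage methods):

import asyncio


async def get_auth_chain_ids(event_ids):
    return ["$a:test", "$b:test"]


async def get_events_as_list(event_ids):
    return [{"event_id": e} for e in event_ids]


# Before (roughly): return get_auth_chain_ids(ids).addCallback(get_events_as_list)
# After: straight-line awaits, which read top to bottom and keep tracebacks simple.
async def get_auth_chain(event_ids):
    ids = await get_auth_chain_ids(event_ids)
    return await get_events_as_list(ids)


print(asyncio.run(get_auth_chain(["$seed:test"])))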
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 532fc30681..b999d91d1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -960,7 +960,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1080,8 +1080,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..aa1ccde211 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -716,7 +716,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. 
Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..e3a154a527 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): 
desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. 
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. 
+ Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. 
@@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From 76d21d14a042756b0c8a8f520dfd9ea09cf092c7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Aug 2020 10:39:31 +0100 Subject: Separate `get_current_token` into two. 
(#8113) The function is used for two purposes: 1) for subscribers of streams to get a token they can use to get further updates with, and 2) for replication to track position of the writers of the stream. For streams with a single writer the two scenarios produce the same result; however, the situation becomes complicated for streams with multiple writers. The current `MultiWriterIdGenerator` does not correctly handle the first case (which is not an issue as it's only used for the `caches` stream which nothing subscribes to outside of replication). --- changelog.d/8113.misc | 1 + .../slave/storage/_slaved_id_tracker.py | 8 +++++ synapse/replication/tcp/streams/_base.py | 2 +- synapse/storage/databases/main/cache.py | 4 +-- synapse/storage/util/id_generators.py | 36 ++++++++++++++++------ tests/storage/test_id_generators.py | 16 +++++----- 6 files changed, 47 insertions(+), 20 deletions(-) create mode 100644 changelog.d/8113.misc (limited to 'tests') diff --git a/changelog.d/8113.misc b/changelog.d/8113.misc new file mode 100644 index 0000000000..00bec4f8ef --- /dev/null +++ b/changelog.d/8113.misc @@ -0,0 +1 @@ +Separate `get_current_token` into two since there are two different use cases for it. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 9d1d173b2f..d43eaf3a29 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 7a42de3f7d..1e92d52165 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -405,7 +405,7 @@ class CachesStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, + store.get_cache_stream_token_for_writer, store.get_all_updated_caches, ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 10de446065..1e7637a6f5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) - def get_cache_stream_token(self, instance_name): + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: - return self._cache_id_gen.get_current_token(instance_name) + return self._cache_id_gen.get_current_token_for_writer(instance_name) else: return 0 diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index e2ddd01290..8276a755e5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -158,6 +158,14 @@ class StreamIdGenerator(object): return self._current + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`.
+ """ + return self.get_current_token() + class ChainedIdGenerator(object): """Used to generate new stream ids where the stream must be kept in sync @@ -216,6 +224,14 @@ class ChainedIdGenerator(object): "Attempted to advance token on source for table %r", self._table ) + def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() + class MultiWriterIdGenerator: """An ID generator that tracks a stream that can have multiple writers. @@ -298,7 +314,7 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token() < next_id + assert self.get_current_token_for_writer(self._instance_name) < next_id with self._lock: self._unfinished_ids.add(next_id) @@ -344,16 +360,18 @@ class MultiWriterIdGenerator: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, next_id) - def get_current_token(self, instance_name: str = None) -> int: - """Gets the current position of a named writer (defaults to current - instance). - - Returns 0 if we don't have a position for the named writer (likely due - to it being a new writer). + def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. """ - if instance_name is None: - instance_name = self._instance_name + # Currently we don't support this operation, as it's not obvious how to + # condense the stream positions of multiple writers into a single int. + raise NotImplementedError() + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + """ with self._lock: return self._current_positions.get(instance_name, 0) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index e845410dae..7a05194653 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. 
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_multi_instance(self): """Test that reads and writes from multiple processes are handled @@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen = self._create_id_generator("second") self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token("first"), 3) - self.assertEqual(first_id_gen.get_current_token("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) -- cgit 1.5.1 From d294f0e7e18f4ba91b5ca6a944758d5b92d1ea2a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 19 Aug 2020 07:09:07 -0400 Subject: Remove the unused inlineCallbacks code-paths in the caching code (#8119) --- changelog.d/8119.misc | 1 + synapse/util/caches/descriptors.py | 54 ++++++----------------------------- tests/util/caches/test_descriptors.py | 12 ++++---- 3 files changed, 15 insertions(+), 52 deletions(-) create mode 100644 changelog.d/8119.misc (limited to 'tests') diff --git a/changelog.d/8119.misc b/changelog.d/8119.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8119.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index c2d72a82cf..49d9fddcf0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -285,16 +285,9 @@ class Cache(object): class _CacheDescriptorBase(object): - def __init__( - self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False - ): + def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -364,7 +357,7 @@ class CacheDescriptor(_CacheDescriptorBase): invalidated) by adding a special "cache_context" argument to the function and passing that as a kwarg to all caches called. For example:: - @cachedInlineCallbacks(cache_context=True) + @cached(cache_context=True) def foo(self, key, cache_context): r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) @@ -382,17 +375,11 @@ class CacheDescriptor(_CacheDescriptorBase): max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False, ): - super(CacheDescriptor, self).__init__( - orig, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - cache_context=cache_context, - ) + super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries self.tree = tree @@ -465,9 +452,7 @@ class CacheDescriptor(_CacheDescriptorBase): observer = defer.succeed(cached_result_d) except KeyError: - ret = defer.maybeDeferred( - preserve_fn(self.function_to_call), obj, *args, **kwargs - ) + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) def onErr(f): cache.invalidate(cache_key) @@ -510,9 +495,7 @@ class CacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__( - self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False - ): + def __init__(self, orig, cached_method_name, list_name, num_args=None): """ Args: orig (function) @@ -521,12 +504,8 @@ class CacheListDescriptor(_CacheDescriptorBase): num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks """ - super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks - ) + super().__init__(orig, num_args=num_args) self.list_name = list_name @@ -631,7 +610,7 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers.append( defer.maybeDeferred( - preserve_fn(self.function_to_call), **args_to_call + preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback) ) @@ -695,21 +674,7 @@ def cached( ) -def cachedInlineCallbacks( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - tree=tree, - inlineCallbacks=True, - cache_context=cache_context, - iterable=iterable, - ) - - -def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=None): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. 
A single argument @@ -725,8 +690,6 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal do batch lookups in the cache. num_args (int): Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. - inlineCallbacks (bool): Should the function be wrapped in an - `defer.inlineCallbacks`? Example: @@ -744,5 +707,4 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, - inlineCallbacks=inlineCallbacks, ) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 4d2b9e0d64..0363735d4f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): assert current_context().request == "c1" # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() assert current_context().request == "c1" return self.mock(args1, arg2) @@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() return self.mock(args1, arg2) obj = Cls() -- cgit 1.5.1 From f594e434c35ab99bc71216cbb06082aa2b975980 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 19 Aug 2020 08:07:57 -0400 Subject: Switch the JSON byte producer from a pull to a push producer. (#8116) --- changelog.d/8116.feature | 1 + synapse/http/server.py | 75 +++++++++++++++++------------ tests/rest/client/v1/test_login.py | 16 ++---- tests/rest/client/v2_alpha/test_register.py | 4 +- tests/storage/test_cleanup_extrems.py | 3 +- 5 files changed, 53 insertions(+), 46 deletions(-) create mode 100644 changelog.d/8116.feature (limited to 'tests') diff --git a/changelog.d/8116.feature b/changelog.d/8116.feature new file mode 100644 index 0000000000..b1eaf1e78a --- /dev/null +++ b/changelog.d/8116.feature @@ -0,0 +1 @@ +Iteratively encode JSON to avoid blocking the reactor. diff --git a/synapse/http/server.py b/synapse/http/server.py index 37fdf14405..8d791bd2ca 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect): pass -@implementer(interfaces.IPullProducer) +@implementer(interfaces.IPushProducer) class _ByteProducer: """ Iteratively write bytes to the request. @@ -515,52 +515,64 @@ class _ByteProducer: ): self._request = request self._iterator = iterator + self._paused = False - def start(self) -> None: - self._request.registerProducer(self, False) + # Register the producer and start producing data. + self._request.registerProducer(self, True) + self.resumeProducing() def _send_data(self, data: List[bytes]) -> None: """ - Send a list of strings as a response to the request. + Send a list of bytes as a chunk of a response. 
""" if not data: return self._request.write(b"".join(data)) + def pauseProducing(self) -> None: + self._paused = True + def resumeProducing(self) -> None: # We've stopped producing in the meantime (note that this might be # re-entrant after calling write). if not self._request: return - # Get the next chunk and write it to the request. - # - # The output of the JSON encoder is coalesced until min_chunk_size is - # reached. (This is because JSON encoders produce a very small output - # per iteration.) - # - # Note that buffer stores a list of bytes (instead of appending to - # bytes) to hopefully avoid many allocations. - buffer = [] - buffered_bytes = 0 - while buffered_bytes < self.min_chunk_size: - try: - data = next(self._iterator) - buffer.append(data) - buffered_bytes += len(data) - except StopIteration: - # The entire JSON object has been serialized, write any - # remaining data, finalize the producer and the request, and - # clean-up any references. - self._send_data(buffer) - self._request.unregisterProducer() - self._request.finish() - self.stopProducing() - return - - self._send_data(buffer) + self._paused = False + + # Write until there's backpressure telling us to stop. + while not self._paused: + # Get the next chunk and write it to the request. + # + # The output of the JSON encoder is buffered and coalesced until + # min_chunk_size is reached. This is because JSON encoders produce + # very small output per iteration and the Request object converts + # each call to write() to a separate chunk. Without this there would + # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n"). + # + # Note that buffer stores a list of bytes (instead of appending to + # bytes) to hopefully avoid many allocations. + buffer = [] + buffered_bytes = 0 + while buffered_bytes < self.min_chunk_size: + try: + data = next(self._iterator) + buffer.append(data) + buffered_bytes += len(data) + except StopIteration: + # The entire JSON object has been serialized, write any + # remaining data, finalize the producer and the request, and + # clean-up any references. + self._send_data(buffer) + self._request.unregisterProducer() + self._request.finish() + self.stopProducing() + return + + self._send_data(buffer) def stopProducing(self) -> None: + # Clear a circular reference. self._request = None @@ -620,8 +632,7 @@ def respond_with_json( if send_cors: set_cors_headers(request) - producer = _ByteProducer(request, encoder(json_object)) - producer.start() + _ByteProducer(request, encoder(json_object)) return NOT_DONE_YET diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index db52725cfe..2668662c9e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. 
self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 53a43038f0..2fc3a60fc5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 8e9a650f9f..43639ca286 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ "3" ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() # All entries within time frame self.assertEqual( @@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): 3, ) # Oldest room to expire - self.pump(1) + self.pump(1.01) 
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() self.assertEqual( len( -- cgit 1.5.1 From e259d63f73fd7599520d0c4a6f5082e5cd383d25 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:07:42 -0400 Subject: Stop shadow-banned users from sending invites. (#8095) --- changelog.d/8095.feature | 1 + synapse/api/errors.py | 8 +++ synapse/handlers/room.py | 16 +++++- synapse/handlers/room_member.py | 62 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 3 ++ synapse/rest/client/v1/room.py | 67 +++++++++++++++---------- tests/rest/client/v1/test_rooms.py | 100 +++++++++++++++++++++++++++++++++++++ 7 files changed, 226 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8095.feature (limited to 'tests') diff --git a/changelog.d/8095.feature b/changelog.d/8095.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8095.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/errors.py b/synapse/api/errors.py index a3f314118a..4888c0ec4d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -604,3 +604,11 @@ class HttpResponseException(CodeMessageException): errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) + + +class ShadowBanError(Exception): + """ + Raised when a shadow-banned user attempts to perform an action. + + This should be caught and a proper "fake" success response sent to the user. + """ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 442cca28e6..0fc71475c3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,6 +20,7 @@ import itertools import logging import math +import random import string from collections import OrderedDict from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple @@ -626,6 +627,7 @@ class RoomCreationHandler(BaseHandler): if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) + invite_3pid_list = config.get("invite_3pid", []) invite_list = config.get("invite", []) for i in invite_list: try: @@ -634,6 +636,14 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) + if (invite_list or invite_3pid_list) and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + + # Allow the request to go through, but remove any associated invites. + invite_3pid_list = [] + invite_list = [] + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") @@ -648,8 +658,6 @@ class RoomCreationHandler(BaseHandler): % (user_id,), ) - invite_3pid_list = config.get("invite_3pid", []) - visibility = config.get("visibility", None) is_public = visibility == "public" @@ -744,6 +752,8 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct + # Note that update_membership with an action of "invite" can raise a + # ShadowBanError, but this was handled above by emptying invite_list. 
_, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), @@ -758,6 +768,8 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] + # Note that do_3pid_invite can raise a ShadowBanError, but this was + # handled above by emptying invite_3pid_list. last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index aa1ccde211..3a6ee6378d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -15,6 +15,7 @@ import abc import logging +import random from http import HTTPStatus from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -22,7 +23,13 @@ from unpaddedbase64 import encode_base64 from synapse import types from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + ShadowBanError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import EventFormatVersions from synapse.crypto.event_signing import compute_event_reference_hash @@ -285,6 +292,31 @@ class RoomMemberHandler(object): content: Optional[dict] = None, require_consent: bool = True, ) -> Tuple[str, int]: + """Update a user's membership in a room. + + Params: + requester: The user who is performing the update. + target: The user whose membership is being updated. + room_id: The room ID whose membership is being updated. + action: The membership change, see synapse.api.constants.Membership. + txn_id: The transaction ID, if given. + remote_room_hosts: Remote servers to send the update to. + third_party_signed: Information from a 3PID invite. + ratelimit: Whether to rate limit the request. + content: The content of the created event. + require_consent: Whether consent is required. + + Returns: + A tuple of the new event ID and stream ID. + + Raises: + ShadowBanError if a shadow-banned requester attempts to send an invite. + """ + if action == Membership.INVITE and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -773,6 +805,25 @@ class RoomMemberHandler(object): txn_id: Optional[str], id_access_token: Optional[str] = None, ) -> int: + """Invite a 3PID to a room. + + Args: + room_id: The room to invite the 3PID to. + inviter: The user sending the invite. + medium: The 3PID's medium. + address: The 3PID's address. + id_server: The identity server to use. + requester: The user making the request. + txn_id: The transaction ID this is part of, or None if this is not + part of a transaction. + id_access_token: The optional identity server access token. + + Returns: + The new stream ID. + + Raises: + ShadowBanError if the requester has been shadow-banned. + """ if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -780,6 +831,11 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. 
+ await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. await self.base_handler.ratelimit(requester) @@ -804,6 +860,8 @@ class RoomMemberHandler(object): ) if invitee: + # Note that update_membership with an action of "invite" can raise + # a ShadowBanError, but this was done above already. _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) @@ -1042,7 +1100,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id # The room is too large. Leave. - requester = types.create_requester(user, None, False, None) + requester = types.create_requester(user, None, False, False, None) await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7c292ef3f9..09726d52d6 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet): join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + # update_membership with an action of "invite" can raise a + # ShadowBanError. This is not handled since it is assumed that + # an admin isn't going to call this API with a shadow-banned user. await self.room_member_handler.update_membership( requester=requester, target=fake_requester.user, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f216382636..a9dd3a6aec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -27,6 +27,7 @@ from synapse.api.errors import ( Codes, HttpResponseException, InvalidClientCredentialsError, + ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter @@ -45,6 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder +from synapse.util.stringutils import random_string MYPY = False if MYPY: @@ -200,14 +202,17 @@ class RoomStateEventRestServlet(TransactionRestServlet): event_dict["state_key"] = state_key if event_type == EventTypes.Member: - membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - requester, - target=UserID.from_string(state_key), - room_id=room_id, - action=membership, - content=content, - ) + try: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + except ShadowBanError: + event_id = "$" + random_string(43) else: ( event, @@ -719,16 +724,20 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - await self.room_member_handler.do_3pid_invite( - room_id, - requester.user, - content["medium"], - content["address"], - content["id_server"], - requester, - txn_id, - content.get("id_access_token"), - ) + try: + await self.room_member_handler.do_3pid_invite( + room_id, + requester.user, + content["medium"], + content["address"], + content["id_server"], + requester, + txn_id, + 
content.get("id_access_token"), + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return 200, {} target = requester.user @@ -740,15 +749,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content: event_content = {"reason": content["reason"]} - await self.room_member_handler.update_membership( - requester=requester, - target=target, - room_id=room_id, - action=membership_action, - txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), - content=event_content, - ) + try: + await self.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, + txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), + content=event_content, + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return_value = {} diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ef6b775ed2..e674eb90d7 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1974,3 +1974,103 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) + + +class ShadowBannedTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. 
+ identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) -- cgit 1.5.1 From 3f91638da6ea0aeaf789ddc8ca1e624a11b7ebb2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:42:58 -0400 Subject: Allow denying or shadow banning registrations via the spam checker (#8034) --- changelog.d/8034.feature | 1 + synapse/events/spamcheck.py | 35 ++++++++++++++- synapse/handlers/auth.py | 8 ++++ synapse/handlers/cas_handler.py | 11 ++++- synapse/handlers/oidc_handler.py | 21 +++++++-- synapse/handlers/register.py | 26 ++++++++++- synapse/handlers/saml_handler.py | 18 +++++++- synapse/rest/client/v2_alpha/register.py | 5 +++ synapse/spam_checker_api/__init__.py | 11 +++++ .../main/schema/delta/58/07persist_ui_auth_ips.sql | 25 +++++++++++ synapse/storage/databases/main/ui_auth.py | 39 +++++++++++++++- tests/handlers/test_oidc.py | 18 ++++++-- tests/handlers/test_register.py | 52 +++++++++++++++++++++- tests/handlers/test_user_directory.py | 6 +-- 14 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8034.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql (limited to 'tests') diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8034.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..a7cddac974 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,9 +15,10 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. 
+ + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d6870e40..654f58ddae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -364,6 +364,14 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( session.session_id, self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 786e608fa2..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -35,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -210,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index dd3703cbd2..c5bd2fea68 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -689,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. 
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -828,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -843,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -899,7 +912,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ccd96e4626..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) +from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, - shadow_banned=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. - shadow_banned (bool): Shadow-ban the created user. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. 
Returns: str: user_id Raises: @@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index c1fcb98454..b426199aa6 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -54,6 +54,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -133,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -147,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -155,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. 
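
All of these CAS, OIDC, SAML and UI-auth paths end up forwarding the collected (user agent, IP) pairs into register_user, which hands them to the new check_registration_for_spam hook. As a rough illustration, a spam-checker module could implement the hook as in the sketch below; the class name, the denylisted IP and the shadow-ban policy are invented for the example, while the method signature and the RegistrationBehaviour return values follow the DenyAll/BanAll helpers used in the registration tests further down this patch.

    from synapse.spam_checker_api import RegistrationBehaviour

    # Hypothetical denylist, used only for this illustration.
    KNOWN_BAD_IPS = {"198.51.100.23"}

    class ExampleRegistrationChecker:
        """Sketch of a spam-checker module using the new registration hook."""

        def check_registration_for_spam(self, email_threepid, username, request_info):
            # request_info is a collection of (user_agent, ip) tuples recorded
            # while the client worked through UI auth or SSO.
            for _user_agent, ip in request_info:
                if ip in KNOWN_BAD_IPS:
                    # Let the registration appear to succeed, but shadow-ban it.
                    return RegistrationBehaviour.SHADOW_BAN
            return RegistrationBehaviour.ALLOW
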
@@ -291,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7290fd0756..be0e680ac5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 7f63f1bfa0..9be92e2565 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,6 +26,16 @@ if MYPY: logger = logging.getLogger(__name__) +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + class SpamCheckerApi(object): """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 6281a41a3d..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore): txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) 
self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. 
""" u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From cbbf9126cbd2ace90c1c0f615b87bcec30fdcbd8 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 21 Aug 2020 15:07:56 +0100 Subject: Do not apply ratelimiting on joins to appservices (#8139) Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited. Co-authored-by: Patrick Cloke Co-authored-by: Erik Johnston --- changelog.d/8139.bugfix | 1 + synapse/api/ratelimiting.py | 37 +++++++++++++++++++++ synapse/handlers/room_member.py | 14 ++++---- tests/api/test_ratelimiting.py | 73 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8139.bugfix (limited to 'tests') diff --git a/changelog.d/8139.bugfix b/changelog.d/8139.bugfix new file mode 100644 index 0000000000..21f65d87b7 --- /dev/null +++ b/changelog.d/8139.bugfix @@ -0,0 +1 @@ +Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ec6b3a69a2..e62ae50ac2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.types import Requester from synapse.util import Clock @@ -43,6 +44,42 @@ class Ratelimiter(object): # * The rate_hz of this particular entry. This can vary per request self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + def can_requester_do_action( + self, + requester: Requester, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: + """Can the requester perform the action? + + Args: + requester: The requester to key off when rate limiting. The user property + will be used. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Returns: + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero + """ + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). 
+ if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + return self.can_do_action( + requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + ) + def can_do_action( self, key: Any, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 3a6ee6378d..a03cb02792 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -491,9 +491,10 @@ class RoomMemberHandler(object): if is_host_in_room: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_local.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( @@ -502,9 +503,10 @@ class RoomMemberHandler(object): else: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + 
self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) -- cgit 1.5.1 From 3f49f74610197d32fe73678cabc10f08732e66b8 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 11:33:55 +0100 Subject: Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991) * Don't raise session_id errors on submit_token if request_token_inhibit_3pid_errors is set * Changelog * Also wait some time before responding to /requestToken * Incorporate review * Update synapse/storage/databases/main/registration.py Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> * Incorporate review Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/7991.misc | 1 + synapse/rest/client/v2_alpha/account.py | 10 +++++++++ synapse/rest/client/v2_alpha/register.py | 7 ++++++ synapse/storage/databases/main/registration.py | 25 ++++++++++++++++----- tests/storage/test_registration.py | 31 ++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 changelog.d/7991.misc (limited to 'tests') diff --git a/changelog.d/7991.misc b/changelog.d/7991.misc new file mode 100644 index 0000000000..1562e3af9e --- /dev/null +++ b/changelog.d/7991.misc @@ -0,0 +1 @@ +Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 203e76b9f2..3481477731 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random from http import HTTPStatus from synapse.api.constants import LoginType @@ -109,6 +110,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -448,6 +452,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -516,6 +523,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. 
See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index be0e680ac5..51372cdb5e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,6 +16,7 @@ import hmac import logging +import random from typing import List, Union import synapse @@ -131,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -203,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 068ad22b30..321a51cc6a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. 
+ row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 840db66072..58f827d8d3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes +from synapse.api.errors import ThreepidValidationError from tests import unittest from tests.utils import setup_test_homeserver @@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase): ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) + + @defer.inlineCallbacks + def test_3pid_inhibit_invalid_validation_session_error(self): + """Tests that enabling the configuration option to inhibit 3PID errors on + /requestToken also inhibits validation errors caused by an unknown session ID. + """ + + # Check that, with the config setting set to false (the default value), a + # validation error is caused by the unknown session ID. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Unknown session_id", e) + + # Set the config setting to true. + self.store._ignore_unknown_session_error = True + + # Check that now the validation error is caused by the token not matching. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Validation token not found or has expired", e) -- cgit 1.5.1 From 2df82ae451e03d76fae5381961dd6229d5796400 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 21 Aug 2020 15:07:56 +0100 Subject: Do not apply ratelimiting on joins to appservices (#8139) Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited. Co-authored-by: Patrick Cloke Co-authored-by: Erik Johnston --- changelog.d/8139.bugfix | 1 + synapse/api/ratelimiting.py | 37 +++++++++++++++++++++ synapse/handlers/room_member.py | 14 ++++---- tests/api/test_ratelimiting.py | 73 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8139.bugfix (limited to 'tests') diff --git a/changelog.d/8139.bugfix b/changelog.d/8139.bugfix new file mode 100644 index 0000000000..21f65d87b7 --- /dev/null +++ b/changelog.d/8139.bugfix @@ -0,0 +1 @@ +Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. 
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ec6b3a69a2..e62ae50ac2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.types import Requester from synapse.util import Clock @@ -43,6 +44,42 @@ class Ratelimiter(object): # * The rate_hz of this particular entry. This can vary per request self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + def can_requester_do_action( + self, + requester: Requester, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: + """Can the requester perform the action? + + Args: + requester: The requester to key off when rate limiting. The user property + will be used. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Returns: + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero + """ + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + return self.can_do_action( + requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + ) + def can_do_action( self, key: Any, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..0cd59bce3b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -459,9 +459,10 @@ class RoomMemberHandler(object): if is_host_in_room: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_local.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( @@ -470,9 +471,10 @@ class RoomMemberHandler(object): else: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = 
Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) -- cgit 1.5.1 From 393a811a41d51d7967f6d455017176a20eacd85c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 18:06:04 +0100 Subject: Fix join ratelimiter breaking profile updates and idempotency (#8153) --- changelog.d/8153.bugfix | 1 + synapse/handlers/room_member.py | 46 +++++++++++--------- tests/rest/client/v1/test_rooms.py | 87 +++++++++++++++++++++++++++++++++++++- tests/rest/client/v1/utils.py | 10 +++-- 4 files changed, 119 insertions(+), 25 deletions(-) create mode 100644 changelog.d/8153.bugfix (limited to 'tests') diff --git a/changelog.d/8153.bugfix b/changelog.d/8153.bugfix new file mode 100644 index 0000000000..87a1f46ca1 --- /dev/null +++ b/changelog.d/8153.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. 
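
The core of the fix is to consult the local join limiter only when a membership event genuinely transitions the user into the room, so JOIN-to-JOIN profile updates and repeated idempotent join requests are not counted. Boiled down to a standalone predicate (an illustrative restatement, not the literal handler code; the function name is invented):

    from typing import Optional

    from synapse.api.constants import Membership

    def counts_against_join_limit(membership: str, prev_membership: Optional[str]) -> bool:
        """True only for events that newly join the user to the room."""
        return membership == Membership.JOIN and prev_membership != Membership.JOIN

    # Profile updates re-send a join event over an existing join, so they are exempt:
    assert not counts_against_join_limit(Membership.JOIN, Membership.JOIN)
    # A genuine first join (no prior membership, or a prior invite) is counted:
    assert counts_against_join_limit(Membership.JOIN, None)
    assert counts_against_join_limit(Membership.JOIN, Membership.INVITE)

The tests below pin this behaviour down by capping rc_joins.local at 3 joins per second and checking that a display name change still propagates into all six joined rooms, and that both join endpoints stay idempotent.
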
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0cd59bce3b..9fcabb22c7 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -210,24 +210,40 @@ class RoomMemberHandler(object): _, stream_id = await self.store.get_event_ordering(duplicate.event_id) return duplicate.event_id, stream_id - stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit, - ) - prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + newly_joined = False if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. newly_joined = True if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN + + # Only rate-limit if the user actually joined the room, otherwise we'll end + # up blocking profile updates. if newly_joined: - await self._user_joined_room(target, room_id) + time_now_s = self.clock.time() + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + + stream_id = await self.event_creation_handler.handle_new_client_event( + requester, event, context, extra_users=[target], ratelimit=ratelimit, + ) + + if event.membership == Membership.JOIN and newly_joined: + # Only fire user_joined_room if the user has actually joined the + # room. Don't bother if the user is just changing their profile + # info. + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) @@ -457,19 +473,7 @@ class RoomMemberHandler(object): # so don't really fit into the general auth process. 
raise AuthError(403, "Guest access not allowed") - if is_host_in_room: - time_now_s = self.clock.time() - ( - allowed, - time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester,) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) - ) - - else: + if not is_host_in_room: time_now_s = self.clock.time() ( allowed, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ef6b775ed2..e74bddc1e5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias +from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string from tests import unittest @@ -675,6 +675,91 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for i in range(5): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join more rooms than the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + self.reactor.advance(1) + room_ids = room_ids + [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + store.create_profile(UserID.from_string(self.user_id).localpart) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.render(request) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. 
+ paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for i in range(6): + request, channel = self.make_request("POST", path % room_id, {}) + self.render(request) + self.assertEquals(channel.code, 200) + + class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8933b560d2..e66c9a4c4c 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,9 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator=None, is_public=True, tok=None): + def create_room_as( + self, room_creator=None, is_public=True, tok=None, expect_code=200, + ): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -54,9 +56,11 @@ class RestHelper(object): ) render(request, self.resource, self.hs.get_reactor()) - assert channel.result["code"] == b"200", channel.result + assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - return channel.json_body["room_id"] + + if expect_code == 200: + return channel.json_body["room_id"] def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( -- cgit 1.5.1 From 3f8f96be00104e1d1d42fde8e513985fc66201bf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Aug 2020 13:08:33 -0400 Subject: Fix flaky shadow-ban tests. (#8152) --- changelog.d/8152.feature | 1 + tests/rest/client/v1/test_rooms.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8152.feature (limited to 'tests') diff --git a/changelog.d/8152.feature b/changelog.d/8152.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8152.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e674eb90d7..286e0ccdcc 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -21,7 +21,7 @@ import json from urllib import parse as urlparse -from mock import Mock +from mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership @@ -1976,6 +1976,8 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) +# To avoid the tests timing out don't add a delay to "annoy the requester". 
+@patch("random.randint", new=lambda a, b: 0) class ShadowBannedTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, -- cgit 1.5.1 From 420484a334a79b31e689bdcca2e57d9a23f7e3d4 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 18:21:04 +0100 Subject: Allow capping a room's retention policy (#8104) --- changelog.d/8104.bugfix | 1 + docs/sample_config.yaml | 22 +++++---- synapse/config/server.py | 22 +++++---- synapse/events/validator.py | 59 ++--------------------- synapse/handlers/pagination.py | 36 +++++++++++--- tests/rest/client/test_retention.py | 94 ++++++++++++++++++++++++++----------- 6 files changed, 127 insertions(+), 107 deletions(-) create mode 100644 changelog.d/8104.bugfix (limited to 'tests') diff --git a/changelog.d/8104.bugfix b/changelog.d/8104.bugfix new file mode 100644 index 0000000000..e32e2996c4 --- /dev/null +++ b/changelog.d/8104.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f168853f67..3528d9e11f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -378,11 +378,10 @@ retention: # min_lifetime: 1d # max_lifetime: 1y - # Retention policy limits. If set, a user won't be able to send a - # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' - # that's not within this range. This is especially useful in closed federations, - # in which server admins can make sure every federating server applies the same - # rules. + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. # #allowed_lifetime_min: 1d #allowed_lifetime_max: 1y @@ -408,12 +407,19 @@ retention: # (e.g. every 12h), but not want that purge to be performed by a job that's # iterating over every room it knows, which could be heavy on the server. # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # #purge_jobs: - # - shortest_max_lifetime: 1d - # longest_max_lifetime: 3d + # - longest_max_lifetime: 3d # interval: 12h # - shortest_max_lifetime: 3d - # longest_max_lifetime: 1y # interval: 1d # Inhibits the /requestToken endpoints from returning an error that might leak diff --git a/synapse/config/server.py b/synapse/config/server.py index ed66f3eba1..526a90b26a 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -961,11 +961,10 @@ class ServerConfig(Config): # min_lifetime: 1d # max_lifetime: 1y - # Retention policy limits. If set, a user won't be able to send a - # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' - # that's not within this range. 
This is especially useful in closed federations, - # in which server admins can make sure every federating server applies the same - # rules. + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. # #allowed_lifetime_min: 1d #allowed_lifetime_max: 1y @@ -991,12 +990,19 @@ class ServerConfig(Config): # (e.g. every 12h), but not want that purge to be performed by a job that's # iterating over every room it knows, which could be heavy on the server. # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # #purge_jobs: - # - shortest_max_lifetime: 1d - # longest_max_lifetime: 3d + # - longest_max_lifetime: 3d # interval: 12h # - shortest_max_lifetime: 3d - # longest_max_lifetime: 1y # interval: 1d # Inhibits the /requestToken endpoints from returning an error that might leak diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 588d222f36..5ce3874fba 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -74,15 +74,14 @@ class EventValidator(object): ) if event.type == EventTypes.Retention: - self._validate_retention(event, config) + self._validate_retention(event) - def _validate_retention(self, event, config): + def _validate_retention(self, event): """Checks that an event that defines the retention policy for a room respects the - boundaries imposed by the server's administrator. + format enforced by the spec. Args: event (FrozenEvent): The event to validate. - config (Config): The homeserver's configuration. 
""" min_lifetime = event.content.get("min_lifetime") max_lifetime = event.content.get("max_lifetime") @@ -95,32 +94,6 @@ class EventValidator(object): errcode=Codes.BAD_JSON, ) - if ( - config.retention_allowed_lifetime_min is not None - and min_lifetime < config.retention_allowed_lifetime_min - ): - raise SynapseError( - code=400, - msg=( - "'min_lifetime' can't be lower than the minimum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - - if ( - config.retention_allowed_lifetime_max is not None - and min_lifetime > config.retention_allowed_lifetime_max - ): - raise SynapseError( - code=400, - msg=( - "'min_lifetime' can't be greater than the maximum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - if max_lifetime is not None: if not isinstance(max_lifetime, int): raise SynapseError( @@ -129,32 +102,6 @@ class EventValidator(object): errcode=Codes.BAD_JSON, ) - if ( - config.retention_allowed_lifetime_min is not None - and max_lifetime < config.retention_allowed_lifetime_min - ): - raise SynapseError( - code=400, - msg=( - "'max_lifetime' can't be lower than the minimum allowed value" - " enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - - if ( - config.retention_allowed_lifetime_max is not None - and max_lifetime > config.retention_allowed_lifetime_max - ): - raise SynapseError( - code=400, - msg=( - "'max_lifetime' can't be greater than the maximum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - if ( min_lifetime is not None and max_lifetime is not None diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 487420bb5d..ac3418d69d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -82,6 +82,9 @@ class PaginationHandler(object): self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min + self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max + if hs.config.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.retention_purge_jobs: @@ -111,7 +114,7 @@ class PaginationHandler(object): the range to handle (inclusive). If None, it means that the range has no upper limit. """ - # We want the storage layer to to include rooms with no retention policy in its + # We want the storage layer to include rooms with no retention policy in its # return value only if a default retention policy is defined in the server's # configuration and that policy's 'max_lifetime' is either lower (or equal) than # max_ms or higher than min_ms (or both). @@ -152,13 +155,32 @@ class PaginationHandler(object): ) continue - max_lifetime = retention_policy["max_lifetime"] + # If max_lifetime is None, it means that the room has no retention policy. + # Given we only retrieve such rooms when there's a default retention policy + # defined in the server's configuration, we can safely assume that's the + # case and use it for this room. + max_lifetime = ( + retention_policy["max_lifetime"] or self._retention_default_max_lifetime + ) - if max_lifetime is None: - # If max_lifetime is None, it means that include_null equals True, - # therefore we can safely assume that there is a default policy defined - # in the server's configuration. 
- max_lifetime = self._retention_default_max_lifetime + # Cap the effective max_lifetime to be within the range allowed in the + # config. + # We do this in two steps: + # 1. Make sure it's higher or equal to the minimum allowed value, and if + # it's not replace it with that value. This is because the server + # operator can be required to not delete information before a given + # time, e.g. to comply with freedom of information laws. + # 2. Make sure the resulting value is lower or equal to the maximum allowed + # value, and if it's not replace it with that value. This is because the + # server operator can be required to delete any data after a specific + # amount of time. + if self._retention_allowed_lifetime_min is not None: + max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime) + + if self._retention_allowed_lifetime_max is not None: + max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max) + + logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime) # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 0b191d13c6..d4e7fa1293 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase): } self.hs = self.setup_test_homeserver(config=config) + return self.hs def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_retention_state_event(self): - """Tests that the server configuration can limit the values a user can set to the - room's retention policy. + self.store = self.hs.get_datastore() + self.serializer = self.hs.get_event_client_serializer() + self.clock = self.hs.get_clock() + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. """ room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_day_ms * 4}, + body={"max_lifetime": lifetime}, tok=self.token, - expect_code=400, ) + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_with_state_event_outside_allowed(self): + """Tests that the server configuration can override the policy for a room when + running the purge jobs. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set a max_lifetime higher than the maximum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_hour_ms}, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, - expect_code=400, ) - def test_retention_event_purged_with_state_event(self): - """Tests that expired events are correctly purged when the room's retention policy - is defined by a state event. - """ - room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Check that the event is purged after waiting for the maximum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 1.5) - # Set the room's retention period to 2 days. 
- lifetime = one_day_ms * 2 + # Set a max_lifetime lower than the minimum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": lifetime}, + body={"max_lifetime": one_hour_ms}, tok=self.token, ) - self._test_retention_event_purged(room_id, one_day_ms * 1.5) + # Check that the event is purged after waiting for the minimum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 0.5) def test_retention_event_purged_without_state_event(self): """Tests that expired events are correctly purged when the room's retention policy @@ -140,7 +153,27 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id, increment): + def _test_retention_event_purged(self, room_id: str, increment: float): + """Run the following test scenario to test the message retention policy support: + + 1. Send event 1 + 2. Increment time by `increment` + 3. Send event 2 + 4. Increment time by `increment` + 5. Check that event 1 has been purged + 6. Check that event 2 has not been purged + 7. Check that state events that were sent before event 1 aren't purged. + The main reason for sending a second event is because currently Synapse won't + purge the latest message in a room because it would otherwise result in a lack of + forward extremities for this room. It's also a good thing to ensure the purge jobs + aren't too greedy and purge messages they shouldn't. + + Args: + room_id: The ID of the room to test retention in. + increment: The number of milliseconds to advance the clock each time. Must be + defined so that events in the room aren't purged if they are `increment` + old but are purged if they are `increment * 2` old. + """ # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( @@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): expired_event_id = resp.get("event_id") # Check that we can retrieve the event. - expired_event = self.get_event(room_id, expired_event_id) + expired_event = self.get_event(expired_event_id) self.assertEqual( expired_event.get("content", {}).get("body"), "1", expired_event ) @@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase): # one should still be kept. self.reactor.advance(increment / 1000) - # Check that the event has been purged from the database. - self.get_event(room_id, expired_event_id, expected_code=404) + # Check that the first event has been purged from the database, i.e. that we + # can't retrieve it anymore, because it has expired. + self.get_event(expired_event_id, expect_none=True) - # Check that the event that hasn't been purged can still be retrieved. - valid_event = self.get_event(room_id, valid_event_id) + # Check that the event that hasn't expired can still be retrieved. + valid_event = self.get_event(valid_event_id) self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) # Check that we can still access state events that were sent before the event that # has been purged. 
self.get_event(room_id, create_event.event_id) - def get_event(self, room_id, event_id, expected_code=200): - url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + def get_event(self, event_id, expect_none=False): + event = self.get_success(self.store.get_event(event_id, allow_none=True)) - request, channel = self.make_request("GET", url, access_token=self.token) - self.render(request) + if expect_none: + self.assertIsNone(event) + return {} - self.assertEqual(channel.code, expected_code, channel.result) + self.assertIsNotNone(event) - return channel.json_body + time_now = self.clock.time_msec() + serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + + return serialized class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From cbd8d83da7d24d7434c749c4c6cfece0c507b0b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Aug 2020 13:58:56 -0400 Subject: Stop shadow-banned users from sending non-member events. (#8142) --- changelog.d/8142.feature | 1 + synapse/handlers/directory.py | 6 ++ synapse/handlers/message.py | 10 +++ synapse/handlers/room.py | 19 +++++- synapse/rest/client/v1/room.py | 74 +++++++++++++--------- synapse/rest/client/v2_alpha/relations.py | 18 ++++-- .../client/v2_alpha/room_upgrade_rest_servlet.py | 14 ++-- tests/rest/client/v1/test_rooms.py | 55 +++++++++++++++- 8 files changed, 155 insertions(+), 42 deletions(-) create mode 100644 changelog.d/8142.feature (limited to 'tests') diff --git a/changelog.d/8142.feature b/changelog.d/8142.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8142.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 79a2df6201..46826eb784 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -23,6 +23,7 @@ from synapse.api.errors import ( CodeMessageException, Codes, NotFoundError, + ShadowBanError, StoreError, SynapseError, ) @@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler): try: await self._update_canonical_alias(requester, user_id, room_id, room_alias) + except ShadowBanError as e: + logger.info("Failed to update alias events due to shadow-ban: %s", e) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c955a86be0..593c0cc6f1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +import random from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -34,6 +35,7 @@ from synapse.api.errors import ( Codes, ConsentNotGivenError, NotFoundError, + ShadowBanError, SynapseError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions @@ -716,12 +718,20 @@ class EventCreationHandler(object): event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, + ignore_shadow_ban: bool = False, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. See self.create_event and self.send_nonmember_event. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ + if not ignore_shadow_ban and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() # We limit the number of concurrent event sends in a room so that we # don't fork the DAG too much. If we don't limit then we can end up in diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0fc71475c3..e4788ef86b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -136,6 +136,9 @@ class RoomCreationHandler(BaseHandler): Returns: the new room id + + Raises: + ShadowBanError if the requester is shadow-banned. """ await self.ratelimit(requester) @@ -171,6 +174,15 @@ class RoomCreationHandler(BaseHandler): async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): + """ + Args: + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_versions: the version to upgrade the room to + + Raises: + ShadowBanError if the requester is shadow-banned. + """ user_id = requester.user.to_string() # start by allocating a new room id @@ -257,6 +269,9 @@ class RoomCreationHandler(BaseHandler): old_room_id: the id of the room to be replaced new_room_id: the id of the replacement room old_room_state: the state map for the old room + + Raises: + ShadowBanError if the requester is shadow-banned. """ old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, "")) @@ -829,11 +844,13 @@ class RoomCreationHandler(BaseHandler): async def send(etype: str, content: JsonDict, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) + # Allow these events to be sent even if the user is shadow-banned to + # allow the room creation to complete. 
( _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False + creator, event, ratelimit=False, ignore_shadow_ban=True, ) return last_stream_id diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a9dd3a6aec..11da8bc037 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -201,8 +201,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): if state_key is not None: event_dict["state_key"] = state_key - if event_type == EventTypes.Member: - try: + try: + if event_type == EventTypes.Member: membership = content.get("membership", None) event_id, _ = await self.room_member_handler.update_membership( requester, @@ -211,16 +211,16 @@ class RoomStateEventRestServlet(TransactionRestServlet): action=membership, content=content, ) - except ShadowBanError: - event_id = "$" + random_string(43) - else: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) - event_id = event.event_id + else: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) set_tag("event_id", event_id) ret = {"event_id": event_id} @@ -253,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_GET(self, request, room_id, event_type, txn_id): return 200, "Not implemented" @@ -799,20 +806,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_PUT(self, request, room_id, event_id, txn_id): set_tag("txn_id", txn_id) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 89002ffbff..e29f49f7f5 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -22,7 +22,7 @@ any time to reflect changes in the MSC. 
import logging from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import ShadowBanError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -35,6 +35,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.util.stringutils import random_string from ._base import client_patterns @@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - return 200, {"event_id": event.event_id} + return 200, {"event_id": event_id} class RelationPaginationServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index f357015a70..39a5518614 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,13 +15,14 @@ import logging -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.util import stringutils from ._base import client_patterns @@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) - new_version = content["new_version"] new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) if new_version is None: @@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) + try: + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + except ShadowBanError: + # Generate a random room ID. + new_room_id = stringutils.random_string(18) ret = {"replacement_room": new_room_id} diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 286e0ccdcc..60fef13e9f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -27,7 +27,7 @@ import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account +from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet from synapse.types import JsonDict, RoomAlias from synapse.util.stringutils import random_string @@ -1984,6 +1984,7 @@ class ShadowBannedTestCase(unittest.HomeserverTestCase): directory.register_servlets, login.register_servlets, room.register_servlets, + room_upgrade_rest_servlet.register_servlets, ] def prepare(self, reactor, clock, homeserver): @@ -2076,3 +2077,55 @@ class ShadowBannedTestCase(unittest.HomeserverTestCase): # Both users should be in the room. 
users = self.get_success(self.store.get_users_in_room(room_id)) self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) + + def test_message(self): + """Messages from shadow-banned users don't actually get sent.""" + + room_id = self.helper.create_room_as( + self.other_user_id, tok=self.other_access_token + ) + + # The user should be in the room. + self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) + + # Sending a message should complete successfully. + result = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "with right label"}, + tok=self.banned_access_token, + ) + self.assertIn("event_id", result) + event_id = result["event_id"] + + latest_events = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + self.assertNotIn(event_id, latest_events) + + def test_upgrade(self): + """A room upgrade should fail, but look like it succeeded.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), + {"new_version": "6"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + # A new room_id should be returned. + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(new_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) -- cgit 1.5.1 From eba98fb024af4c84901a7ba01940ffb3c50950c8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 25 Aug 2020 17:32:30 +0100 Subject: Add functions to `MultiWriterIdGen` used by events stream (#8164) --- changelog.d/8164.misc | 1 + synapse/storage/util/id_generators.py | 103 +++++++++++++++++++++++++++++++++- synapse/storage/util/sequence.py | 8 ++- tests/storage/test_id_generators.py | 36 ++++++++++++ 4 files changed, 145 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8164.misc (limited to 'tests') diff --git a/changelog.d/8164.misc b/changelog.d/8164.misc new file mode 100644 index 0000000000..55bc079cdb --- /dev/null +++ b/changelog.d/8164.misc @@ -0,0 +1 @@ +Add functions to `MultiWriterIdGen` used by events stream. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ddb5c8c60c..5b07847773 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -14,9 +14,10 @@ # limitations under the License. import contextlib +import heapq import threading from collections import deque -from typing import Dict, Set +from typing import Dict, List, Set from typing_extensions import Deque @@ -210,6 +211,23 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been + # persisted. This is done by a) looking at the min across all instances + # and b) noting that if we have seen a run of persisted positions + # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7). + # + # Note: There is no guarentee that the IDs generated by the sequence + # will be gapless; gaps can form when e.g. 
a transaction was rolled + # back. This means that sometimes we won't be able to skip forward the + # position even though everything has been persisted. However, since + # gaps should be relatively rare it's still worth doing the book keeping + # that allows us to skip forwards when there are gapless runs of + # positions. + self._persisted_upto_position = ( + min(self._current_positions.values()) if self._current_positions else 0 + ) + self._known_persisted_positions = [] # type: List[int] + self._sequence_gen = PostgresSequenceGenerator(sequence_name) def _load_current_ids( @@ -234,9 +252,12 @@ class MultiWriterIdGenerator: return current_positions - def _load_next_id_txn(self, txn): + def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) + def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: + return self._sequence_gen.get_next_mult_txn(txn, n) + async def get_next(self): """ Usage: @@ -262,6 +283,34 @@ class MultiWriterIdGenerator: return manager() + async def get_next_mult(self, n: int): + """ + Usage: + with await stream_id_gen.get_next_mult(5) as stream_ids: + # ... persist events ... + """ + next_ids = await self._db.runInteraction( + "_load_next_mult_id", self._load_next_mult_id_txn, n + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + assert max(self.get_positions().values(), default=0) < min(next_ids) + + with self._lock: + self._unfinished_ids.update(next_ids) + + @contextlib.contextmanager + def manager(): + try: + yield next_ids + finally: + for i in next_ids: + self._mark_id_as_finished(i) + + return manager() + def get_next_txn(self, txn: LoggingTransaction): """ Usage: @@ -326,3 +375,53 @@ class MultiWriterIdGenerator: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) ) + + self._add_persisted_position(new_id) + + def get_persisted_upto_position(self) -> int: + """Get the max position where all previous positions have been + persisted. + + Note: In the worst case scenario this will be equal to the minimum + position across writers. This means that the returned position here can + lag if one writer doesn't write very often. + """ + + with self._lock: + return self._persisted_upto_position + + def _add_persisted_position(self, new_id: int): + """Record that we have persisted a position. + + This is used to keep the `_current_positions` up to date. + """ + + # We require that the lock is locked by caller + assert self._lock.locked() + + heapq.heappush(self._known_persisted_positions, new_id) + + # We move the current min position up if the minimum current positions + # of all instances is higher (since by definition all positions less + # that that have been persisted). + min_curr = min(self._current_positions.values()) + self._persisted_upto_position = max(min_curr, self._persisted_upto_position) + + # We now iterate through the seen positions, discarding those that are + # less than the current min positions, and incrementing the min position + # if its exactly one greater. + # + # This is also where we discard items from `_known_persisted_positions` + # (to ensure the list doesn't infinitely grow). 
+        while self._known_persisted_positions:
+            if self._known_persisted_positions[0] <= self._persisted_upto_position:
+                heapq.heappop(self._known_persisted_positions)
+            elif (
+                self._known_persisted_positions[0] == self._persisted_upto_position + 1
+            ):
+                heapq.heappop(self._known_persisted_positions)
+                self._persisted_upto_position += 1
+            else:
+                # There was a gap in seen positions, so there is nothing more to
+                # do.
+                break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
         return txn.fetchone()[0]
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        txn.execute(
+            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+        )
+        return [i for (i,) in txn]
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 7a05194653..9b9a183e7f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+    def test_get_persisted_upto_position(self):
+        """Test that `get_persisted_upto_position` correctly tracks updates to
+        positions.
+        """
+
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 5)
+
+        id_gen = self._create_id_generator("first")
+
+        # Min is 3 and there is a gap between 5, so we expect it to be 3.
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        # We advance "first" straight to 6. Min is now 5 but there is no gap so
+        # we expect it to be 6
+        id_gen.advance("first", 6)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # No gap, so we expect 7.
+        id_gen.advance("second", 7)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # We haven't seen 8 yet, so we expect 7 still.
+        id_gen.advance("second", 9)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # Now that we've seen 7, 8 and 9 we can got straight to 9.
+        id_gen.advance("first", 8)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+        # Jump forward with gaps. The minimum is 11, even though we haven't seen
+        # 10 we know that everything before 11 must be persisted.
+        id_gen.advance("first", 11)
+        id_gen.advance("second", 15)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 11)
-- cgit 1.5.1
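
The bookkeeping added to `MultiWriterIdGenerator` above can be easier to follow in isolation. The sketch below is an illustrative, simplified re-implementation of just the gap-tracking part and is not code from the patch: the `PositionTracker` class name and its constructor are invented for this example, and the interaction with the per-writer `_current_positions` map that the real `_add_persisted_position` also performs is deliberately omitted.

import heapq
from typing import List


class PositionTracker:
    """Toy model of the persisted-position tracking used above.

    Positions may be reported out of order; the tracker advances its
    "persisted up to" marker only across gapless runs of reported positions.
    """

    def __init__(self, initial: int = 0):
        self._persisted_upto_position = initial
        self._known_persisted_positions = []  # type: List[int]

    def get_persisted_upto_position(self) -> int:
        return self._persisted_upto_position

    def add_persisted_position(self, new_id: int) -> None:
        # Remember the new position, then advance the marker over any
        # gapless run that is now complete.
        heapq.heappush(self._known_persisted_positions, new_id)
        while self._known_persisted_positions:
            smallest = self._known_persisted_positions[0]
            if smallest <= self._persisted_upto_position:
                heapq.heappop(self._known_persisted_positions)
            elif smallest == self._persisted_upto_position + 1:
                heapq.heappop(self._known_persisted_positions)
                self._persisted_upto_position += 1
            else:
                # A gap remains below `smallest`, so we cannot advance yet.
                break


if __name__ == "__main__":
    tracker = PositionTracker(initial=3)
    tracker.add_persisted_position(5)
    assert tracker.get_persisted_upto_position() == 3  # gap at 4, marker stays put
    tracker.add_persisted_position(4)
    assert tracker.get_persisted_upto_position() == 5  # 4 and 5 now form a gapless run

As in `test_get_persisted_upto_position`, a reported position with a gap below it does not move the marker until the gap is filled.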